{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantum Classifier\n",
    "\n",
    "<em> Copyright (c) 2021 Institute for Quantum Computing, Baidu Inc. All Rights Reserved. </em>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "In this tutorial we discuss how a quantum classifier works and how to use a quantum neural network (QNN) to perform a **binary classification** task. Representative early work on this approach includes [Quantum Circuit Learning (QCL)](https://arxiv.org/abs/1803.00745) by Mitarai et al. (2018) [1], the work of Farhi & Neven (2018) [2], and [Circuit-Centric Quantum Classifiers](https://arxiv.org/abs/1804.00633) by Schuld et al. (2018) [3]. Here we take the first of these, the QCL framework, applied to supervised learning as our example. In general, we first encode classical data into quantum data and then train the parameters of a quantum neural network to obtain an optimal classifier."
   ]
  },
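  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Background\n",
    "\n",
    "In supervised learning, we are given a dataset of $N$ labeled data points, $D = \\{(x^k,y^k)\\}_{k=1}^{N}$, where $x^k\\in \\mathbb{R}^{m}$ is a data point and $y^k \\in\\{0,1\\}$ is its class label. **Classification is essentially a decision process: deciding which label a given data point belongs to.** In the quantum classifier framework, the classifier $\\mathcal{F}$ is realized as the combination of a quantum neural network (parameterized quantum circuit) with parameters $\\theta$, a measurement of the quantum system, and classical post-processing. A good classifier $\\mathcal{F}_\\theta$ should map as many data points in the dataset as possible to their correct labels, $\\mathcal{F}_\\theta(x^k) \\rightarrow y^k$. We therefore take the cumulative distance between the predicted labels $\\tilde{y}^{k} = \\mathcal{F}_\\theta(x^k)$ and the true labels $y^k$ as the loss function $\\mathcal{L}(\\theta)$ to be minimized. For binary classification, one can choose the quadratic loss\n",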
    "\n",
    "$$\n",
    "\\mathcal{L}(\\theta) = \\sum_{k=1}^N |\\tilde{y}^{k}-y^k|^2. \\tag{1}\n",
    "$$\n",
    "\n"
   ]
  },
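  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of Eq. (1) with made-up numbers (assumed for this example only, not taken from the tutorial's dataset), take $N=2$ with predictions $\\tilde{y}^{1}=0.9$, $\\tilde{y}^{2}=0.2$ and true labels $y^{1}=1$, $y^{2}=0$:\n",
    "\n",
    "$$\n",
    "\\mathcal{L}(\\theta) = |0.9-1|^2 + |0.2-0|^2 = 0.01 + 0.04 = 0.05.\n",
    "$$\n",
    "\n",
    "A classifier that predicted both labels exactly would reach the minimum $\\mathcal{L}(\\theta)=0$."
   ]
  },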
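  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Workflow\n",
    "\n",
    "Here we outline a workflow for implementing a quantum classifier under the Quantum Circuit Learning (QCL) framework.\n",
    "\n",
    "1. Apply a parameterized unitary gate $U$, whose parameters are set by the data, to the initialized qubits $\\lvert 0 \\rangle$. This encodes the classical data point $x^k$ into quantum data $\\lvert \\psi_{in}\\rangle^k$ that a quantum computer can process.\n",
    "2. Pass the input state $\\lvert \\psi_{in} \\rangle^k$ through the parameterized circuit $U(\\theta)$ with parameters $\\theta$ to obtain the output state $\\lvert \\psi_{out}\\rangle^k = U(\\theta)\\lvert \\psi_{in} \\rangle^k$.\n",
    "3. Measure the state $\\lvert \\psi_{out}\\rangle^k$ produced by the quantum neural network and post-process the results to obtain the estimated label $\\tilde{y}^{k}$.\n",
    "4. Repeat steps 1-3 until every data point in the dataset has been processed, then compute the loss function $\\mathcal{L}(\\theta)$.\n",
    "5. Use an optimization method such as gradient descent to adjust $\\theta$ iteratively and minimize the loss. Record the optimal parameters $\\theta^*$ once optimization finishes; this gives the learned classifier $\\mathcal{F}_{\\theta^*}$.\n",
    "\n",
    "![QCL](figures/qclassifier-fig-pipeline-cn.png \"Figure 1: Flow chart of quantum classifier training\")\n",
    "<div style=\"text-align:center\">Figure 1: Flow chart of quantum classifier training </div>"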
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Paddle Quantum Implementation\n",
    "\n",
    "First, we import the required packages:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:03.419838Z",
     "start_time": "2021-03-02T09:15:03.413324Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import paddle\n",
    "from numpy import pi as PI\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from paddle import matmul, transpose\n",
    "from paddle_quantum.circuit import UAnsatz\n",
    "from paddle_quantum.utils import pauli_str_to_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:03.845958Z",
     "start_time": "2021-03-02T09:15:03.840512Z"
    }
   },
   "outputs": [],
   "source": [
    "# Main functions used in this tutorial\n",
    "__all__ = [\n",
    "    \"circle_data_point_generator\",\n",
    "    \"data_point_plot\",\n",
    "    \"heatmap_plot\",\n",
    "    \"Ry\",\n",
    "    \"Rz\",\n",
    "    \"Observable\",\n",
    "    \"U_theta\",\n",
    "    \"Net\",\n",
    "    \"QC\",\n",
    "    \"main\",\n",
    "]"
   ]
  },
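  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating the Dataset\n",
    "\n",
    "For supervised learning, an unavoidable question is what the dataset looks like. In this tutorial we follow the method described in [1] to generate a simple binary dataset $\\{(x^{k}, y^{k})\\}$ with a circular decision boundary, where the data points are $x^{k}\\in \\mathbb{R}^{2}$ and the labels are $y^{k} \\in \\{0,1\\}$.\n",
    "\n",
    "![Dataset](figures/qclassifier-fig-data-cn.png \"Figure 2: The generated dataset and the corresponding decision boundary\")\n",
    "<div style=\"text-align:center\">Figure 2: The generated dataset and the corresponding decision boundary </div>\n",
    "\n",
    "The code below shows how the data is generated and visualized:"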
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:04.631031Z",
     "start_time": "2021-03-02T09:15:04.617301Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generator for a binary classification dataset with a circular decision boundary\n",
    "def circle_data_point_generator(Ntrain, Ntest, boundary_gap, seed_data):\n",
    "    \"\"\"\n",
    "    :param Ntrain: size of the training set\n",
    "    :param Ntest: size of the test set\n",
    "    :param boundary_gap: width of the gap between the two classes, in (0, 0.5]\n",
    "    :param seed_data: random seed\n",
    "    :return: train_x, train_y, test_x, test_y\n",
    "             (the first 'Ntrain' samples form the training set, the remaining 'Ntest' the test set)\n",
    "    \"\"\"\n",
    "    train_x, train_y = [], []\n",
    "    num_samples, seed_para = 0, 0\n",
    "    while num_samples < Ntrain + Ntest:\n",
    "        # Re-seed on every draw so that the dataset is fully reproducible\n",
    "        np.random.seed((seed_data + 10) * 1000 + seed_para + num_samples)\n",
    "        # Draw a candidate point uniformly from the square [-1, 1) x [-1, 1)\n",
    "        data_point = np.random.rand(2) * 2 - 1\n",
    "\n",
    "        # Label 0 if the norm is below 0.7 - boundary_gap / 2\n",
    "        if np.linalg.norm(data_point) < 0.7 - boundary_gap / 2:\n",
    "            train_x.append(data_point)\n",
    "            train_y.append(0.)\n",
    "            num_samples += 1\n",
    "\n",
    "        # Label 1 if the norm is above 0.7 + boundary_gap / 2\n",
    "        elif np.linalg.norm(data_point) > 0.7 + boundary_gap / 2:\n",
    "            train_x.append(data_point)\n",
    "            train_y.append(1.)\n",
    "            num_samples += 1\n",
    "        # Otherwise the point lies inside the gap: reject it and shift the seed\n",
    "        else:\n",
    "            seed_para += 1\n",
    "\n",
    "    train_x = np.array(train_x).astype(\"float64\")\n",
    "    train_y = np.array([train_y]).astype(\"float64\").T\n",
    "\n",
    "    print(\"Training set shapes: x {} and y {}\".format(np.shape(train_x[0:Ntrain]), np.shape(train_y[0:Ntrain])))\n",
    "    print(\"Test set shapes: x {} and y {}\".format(np.shape(train_x[Ntrain:]), np.shape(train_y[Ntrain:])), \"\\n\")\n",
    "\n",
    "    return train_x[0:Ntrain], train_y[0:Ntrain], train_x[Ntrain:], train_y[Ntrain:]\n",
    "\n",
    "\n",
    "# Visualize the generated dataset\n",
    "def data_point_plot(data, label):\n",
    "    \"\"\"\n",
    "    :param data: array of shape [M, 2], i.e. M two-dimensional data points\n",
    "    :param label: labels of the data points, each 0 or 1\n",
    "    :return: None; plots the data points (label 0 in red, label 1 in blue)\n",
    "    \"\"\"\n",
    "    dim_samples, dim_useless = np.shape(data)\n",
    "    plt.figure(1)\n",
    "    for i in range(dim_samples):\n",
    "        if label[i] == 0:\n",
    "            plt.plot(data[i][0], data[i][1], color=\"r\", marker=\"o\")\n",
    "        elif label[i] == 1:\n",
    "            plt.plot(data[i][0], data[i][1], color=\"b\", marker=\"o\")\n",
    "    plt.show()"
   ]
  },
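  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The labeling rule implemented by `circle_data_point_generator` above can be summarized as follows, writing $g$ for `boundary_gap` (the symbol $g$ is introduced here only for brevity). Each candidate point is drawn uniformly from $[-1,1)^2$; points whose norm falls inside the gap are rejected and redrawn, and the rest are labeled by\n",
    "\n",
    "$$\n",
    "y^{k} = \\begin{cases} 0, & \\|x^{k}\\| < 0.7 - g/2, \\\\ 1, & \\|x^{k}\\| > 0.7 + g/2. \\end{cases}\n",
    "$$\n",
    "\n",
    "For instance, with $g = 0.5$ the inner and outer radii are $0.45$ and $0.95$."
   ]
  },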
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:06.422981Z",
     "start_time": "2021-03-02T09:15:05.043595Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set shapes: x (200, 2) and y (200, 1)\n",
      "Test set shapes: x (100, 2) and y (100, 1) \n",
      "\n",
      "Visualization of the 200 training data points:\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAApn0lEQVR4nO2df6xexXnnP4+vYydu1Ma89mYdwNeQZZuQ3QrCVZSkUpqkJCHuClOVtiSG3qRUXm6b7kpRqxhZ6kZsrdL+Q6gSRCxKcLlXCQlVFLclYoGEXWk3kFx2AfNDxsYJYErCxQ6RIigJZvaPc974+PX5/XPOe74fafSeM+fX886ZM8/MPM/MmHMOIYQQw2VV1wIIIYToFikCIYQYOFIEQggxcKQIhBBi4EgRCCHEwFndtQBl2LBhg9uyZUvXYgghRK944IEHXnDObZyM76Ui2LJlC8vLy12LIYQQvcLMnoqLV9eQEEIMHCkCIYQYOFIEQggxcKQIhBBi4EgRCCHEwKlFEZjZzWb2vJk9knDczOxvzeyQmT1sZu+MHJs3s4NhmK9DHiHKsLQEW7bAqlXB79JS1xL5ic/p5JNsPsmSiXOucgDeB7wTeCTh+Fbgm4AB7wbuD+NPAw6Hv+vD7fVZz7vgggtcn1lcdG521jmz4HdxsWuJxOKic+vWOQcnwrp1ejeTNJVOdXwTPr1Dn2SJAiy7uDI6LrJMALakKIIvAh+L7B8ANgEfA76YdF5S6LMiiMsgZs4tLHQtWbP4rvxmZ09+J+MwO9u1ZH7RRDrVVWj69A59kiVKkiJoy0ZwOvBMZP9IGJcUfwpmtsPMls1seWVlpTFBm2bXLnjppZPjnIMbb6y/6ehL03RpCXbsgKeeCv7rU08F+z41lZ9+ulj8UHkqdjhStXSK+yZeeimIL4JP79AnWfLQG2Oxc26Pc27OOTe3ceMpI6R7Q9KH5FzxjJ+GT4VvXR96k2zeXCy+SXxR4JMsLYFZ/LEy6TT+n3UpF5/eoU+y5KEtRfAscGZk/4wwLil+akn6kKDe2oJPhW8fake7d8O6dSfHrVsXxLeJTwp8kl27ApkmMSueTtH/mUTRQrPKO6yifOOu9SU/5Sauv6hMIN1G8FucbCz+rjthLP4+gaF4fbh9Wtaz6rARdNVnHddvGNd/WFU+s/hnmNX4Z3Lia3/pJD7YMZpMq6byFBSXJel/VjWsLi46NxqduM9olH2fKjaKtGt9yE+T0KSxGPgy8Bzwc4J+/iuBq4CrwuMGfAF4EtgPzEWu/UPgUBg+med5VRVBlxb9tMw/fn4d8vlU+PrqQeEjTSlw3/JUmlKpUmiW+Z9V/pdP31keGlUEbYeqiiDp5c3MNK+9o7WVaPilX8qWr0jm6qrwTaoF+Vg78pGmChbf8pRP/zOP8k3Kv1Wu7QIpgghptZGmC83FRefWrDn5WWvWnPysumqFbWdA1fyr01Qa+panfPqfWcojTdYq13aBFEGErP7Jppt3WR9T35qbY/oqt280ocB9fDe+/M+swjrtnlWu7QIpgghxL6+O2lJd+FaLyItPBmpxMnXnKR+6O+JkKPs/0/5PVr6ucm3bSBFMsLgY2AS6aBHkla/rD60ovtV+xMn43q1Tlwx1fzvTZEyWIoghy1bQh8LXJ3woIETz+FC4tSlDU+6lXSBFEEOarWA0quURXtFGK6OPLRlRDB+6O9qWoUq+9umbkCKIIc6DB5x73eumrwDzrWYi+svQWgRJ+FTA5yVJEfRmrqEm2L4dbr4ZRqMTcaMRfOlLwbFpwqcpJ0S/8WH6hK5l8HkqkDJYoCT6xdzcnFteXu5ajF6xalWQYScxg9dea18e0W+WloJKxNNPB3MC7d7dfuWpSxmSJsubnYUf/KAdGcpgZg845+ZOiZciGAZ9zbhC+EhfK1ZJimDQXUNp+DoVcFm6bkoLMU1UmWbax7JFiiCGNvr/imSGOjLO9u2wZ0/QAjALfvfs
mT5biBBtULZi5a1tIc6C7HtoeqnKpj0SinjwyNtHiPop4/Ezec3CQvF7dO3thNxH89O0j3KRzNB1xhFi2shTuYor9OuokHU9BiNJEchYHEPThtUihqa+GqWE8JWs73vcfRN1tzaL/w6LlgldO200aiw2s4vM7ICZHTKznTHHrzOzB8PwhJm9GDl2PHJsXx3yVKVpw2oRQ1NZo5SPBikhmqJIfs9aOjVuzE1SfbnocqveOm3ENROKBGCGYOWxs4E1wEPAuSnn/ylwc2T/p0Wf2XTXkHPNjhps2kYgu4IYEkXze1Z3a571Sqp00XY5IpmmbATAe4A7I/tXA1ennP9/gA9F9r1UBE1TJDMUzTiyK4ghUTS/l11DYFJB9LFy1aQiuBS4KbJ/BfD5hHNnCdY2nonEvQosA/cBl+R55jQogibp2iAlRJuUye9plaskRTH2EoITU9j3ZY6hMUmKoO1xBJcBtzvnjkfiZl1gvPg48Dkze2vchWa2w8yWzWx5ZWWlDVl7S5XBLkL0jTL5ffv2wDj72mvBb3Q8TdKYmxtuONHHfzwswbwZB1CROhTBs8CZkf0zwrg4LgO+HI1wzj0b/h4G7gXOj7vQObfHOTfnnJvbuHFjVZmngiQDmbcGKSEaoIn8nqQoyk7e2PYA0sLENROKBGA1cBg4ixPG4nfEnPc24AeE8xuFceuBteH2BuAgKYbmcVDXUHY/Zx+nyBWiLHXk9zz3KNsN5csAUpocUAZsBZ4g8B7aFcZdA1wcOeezwLUT170X2B8qj/3AlXmeJ0UQLJwjg7AQyRR1yMhTAJdxxPBpAGmjiqDtMHRFsLgYn1lkEBYioG6X0rL3dS5/K6KN7zpJEWjSuRR8HZSV1h8pg7AQxfvy0waZRcuBXbtgfr7Y5I15jNnj0cxF71EbcdrB99DWgDJfB2WlDXjxQT4huqZoX35Si2A0ql4OxJUla9YE9x53WyV19bZlI1CLIAGfl3ZMqh2MRppWWggo7lKa5HkE1cuBSXfU0Sgo4o8eDX6feirYTqKN6eKlCBLImo+kS5Iy7fXXdyOPEL5R1KU0aezAsWPx5xctB6LuqG98I/z85/mum51tp3InRZBAWo2ia9uBFpkRIp0y30jc2IEmBmfmVSKtjv2J6y/yPXRpI6hrXnIhhP80YStMs0fEubvWOSYI2QiKkVSjuOMOf20HQoh6aaL1nda1G22RAGzYAJdf3vzSllIEMYy7fq64Iti/9dYTTUWfbQdCiPpJm5eoTDdxknKBE/fasAE++cl4I3ITFU8pggmyFpfWhG5CCKi2EP2kcoGT73X0aLpBue6K5yAVQZoWz3IbrWuCq64NzkKIatTpYh53rzRqr3jGGQ58D1WMxVnGnzwDUaoab3werCaEyEed634UWRWtSlmBFq8PyFo8uo3FpbtewFoIUZ06v+Oke00yGgVG5bLG6kYXr+8TWcbeNubyl8FZiP5TZ1kRd681a4KCf2xQXlyEF15oZrzQ4BRBlrG3jcFaMjgL0X/qLCvi7nXzzUHBH+etVDtx/UW+hyZtBHWSZEuQjUCIYeDbAlFoPYITtPFytIKYEMPGxwpfo4oAuAg4ABwCdsYc/wSwAjwYhj+KHJsnWKLyIDCf53l9WJim6ZWGhBB+U6QMaKtimKQIVlftWjKzGeALwIeAI8D3zGyfc+6xiVNvc859auLa04D/BswBDnggvPbHVeXqGhmEhRg2ecuA8cC08TiC8cA0aG8iyTqMxe8CDjnnDjvnfgZ8BdiW89qPAHc5546Fhf9dBK2L3iODsBDDJm8ZkDUwrY3Bp3UogtOBZyL7R8K4SX7HzB42s9vN7MyC12JmO8xs2cyWV1ZWahC7WdpwQxVC+EveMiBrmcyy01gUoS330X8Etjjnfo2g1r+36A2cc3ucc3POubmNGzfWLmDdaM0AIYbN9u3B+sYzM8H+zEywP1kGpLUc2lopsQ5F8CxwZmT/jDDuFzjnjjrnXgl3bwIuyHtt
n0mbtVAIMd0sLcHevXD8eLB//HiwP1mbT2s5tGVrrEMRfA84x8zOMrM1wGXAvugJZrYpsnsx8Hi4fSfwYTNbb2brgQ+HcUII0Wvy1ubTeg/asjVW9hpyzr1qZp8iKMBngJudc4+a2TUErkr7gP9iZhcDrwLHCNxJcc4dM7P/TqBMAK5xziWsEiqEEP0hae6guPjt2+N7DHbvPtmjCJqxNQ5u0jkhhGiD1atPdAtFmZmBV1/Nf5+lpaAV8fTTQUtg9+76J52r3CIQQghxKnFKIC0+iaTWQp0MbtI5IYRog9nZYvFdIkUghBAN0KexRFIEYrqpa1im1hYVBenTWCIpghj0zU8JdQ3LbGt4p5g6+jKWSF5DE0xOAAVBc85XTS5SqGstQa0tKqYELVWZk7aGdIsWKDIsM60ZqKlkxZQjRTCBvnkPKdtXl3dYZlbXj6aSFVOOFMEE+uY9o0r/fF63jaxmYJ/cP4QogRTBBPrmPaNKX11et42sZmDd7h/yRhAZtJ5F4pYt8z00vVSl1hP2CLP49f7M6ntGm+uK+riQrfCKJrMICUtVqkUQQ19cvgZBHX11WdWrss3AMtU2eSOIDDrJInHawffQh8XrRU1UrR7lvT6pGZgWX0auNlo4otc0mUVIaBF0XqiXCVIEA6NKX12Vbp+0wr7sfdvshhK9pMkskqQI1DUk/CGpq6VKX10Vf+C0NnrZ+8obQWTQRRaRIhB+0NQ0DlVsDGmFfdn79mkCGtEJnWSRuGZC0QBcBBwADgE7Y45/GngMeBi4B5iNHDsOPBiGfXmep66hKaSp9nAVG0OaTPL+ET2EprqGzGwG+ALwUeBc4GNmdu7Eaf8PmHPO/RpwO/A3kWMvO+fOC8PFVeURPaWpId1VqldpbfS0+9bhBK6xBqJN4rRDkQC8B7gzsn81cHXK+ecD/zuy/9Oiz1SLYArxwYgaZ5Quaqiuo6Wg1oZoCBo0Fp8OPBPZPxLGJXEl8M3I/uvNbNnM7jOzS5IuMrMd4XnLKysrlQQWHtK1ETXJRgHFDNV1OIHX6UiuloXIQ5x2KBKAS4GbIvtXAJ9POPdy4D5gbSTu9PD3bOAHwFuznqkWwZTS9pDu6PNmZuppkdThBJ50DygmS56WhYbRDwqaGkdAzq4h4ELgceDfpNzrFuDSrGdKEQyQMgVWWlfPuHBOKnDLjuKpo4sr6R5m9Y6hUBfU4GhSEawGDgNnAWuAh4B3TJxzPvAkcM5E/Ppx6wDYABwEzs16phRBDylb81xcdG40OrUwyyqw4gq5NWuce93rsgv/pAI8z3+oy0aQpKSKKJSs1okPdhnRKo0pguDebAWeCAv7XWHcNcDF4fbdwI+YcBMF3gvsD5XHfuDKPM+TIugZZQvHuOvyFlhJhVyREJWxyH+oo7uljhZKVkGv6S4GR6OKoO0gRdAz6p6OIU+BlafbJy2sWnVyAd527TnP87IUzsJC/D0WFrr5T6JzkhSBRhaL5ik7RiBuneAoaaN4q64k9NprJ++3vXRdkhfV1q2B948ZXHFF+kjsO+6Iv/c4vmtPLeENUgSiecpOxzAzk3wsq8DavTsoLKsQdddse+m6uAFr8/Owd+8JBencyddMupi2veCO6C9xzQTfg7qGekZZG0Fa102efvesbqXZ2eTuk8muJx88bPLYPaIyq+tHTIC6hkRnlK15zs4mx+eptaZdPx4gdsMNMBrFnxet7ftQe87TDRWVWV0/Ii9x2sH3oBbBQGhzUZqua/t5yGoRFFlwRwwS5DUkeknVgizv9X0oMOMU1tg7qk2Z+5BWIhYpAjFdDLUw6vp/96X1JGJJUgQWHOsXc3Nzbnl5uWsxRFeMJ4iLTsy2bp08XppmaSnwXDp+/NRjs7OBzUV4jZk94Jybm4yXsVj0jzpn5xT5GCvfOCUAzY2nEK0gRSC6p+hUyW0P7ppm8qZ9nPKN4pymue4xUgSiW8qsVdz24C6fqbLeQJG0z6Nk61pnWrRPnOHA9yBj
8RRRZtBTnwyWTRp3s9Ih69lF0r7IJH4asOYtyGtIeEnZGTC79p7JQ9MKK60gz/PsImmfNRNskXcnOiNJEchrSHTLli3xk8tNgxdK0/9t1apT5xuCYOTz5s3Zzy4q39JSYCt4+ung/j/9KRw9mv960TnyGhJ+Ms3TIDRt1E6zleR5dtG037795PWbr79+et/dwJAiEN3S5Rw+TS/s3rRRO60gz/PsImkfl1Y+zL8k6iGuv6hoAC4CDgCHgJ0xx9cCt4XH7we2RI5dHcYfAD6S53l12Aj60MUsGqSNhd3bMGonyVjns/tknBep0OCaxTMES1SezYk1i8+dOOePgRvD7cuA28Ltc8Pz1xKsefwkMJP1zKqKQPlatLawe5c1jjqevbjo3MxMelqJ3pCkCCobi83sPcBnnXMfCfevDlsafxU5587wnO+Y2Wrgh8BGYGf03Oh5ac+saiyeZvukyEmaofW115RJIH4qjyjjtBK9oUlj8enAM5H9I2Fc7DnOuVeBnwCjnNcCYGY7zGzZzJZXVlYqCayBqSKzD71KJmna9tAWWaOJhziAb0rpjbHYObfHOTfnnJvbuHFjpXtpYKrI9Jgpm0nKjJT2lTSlJ++gqaIORfAscGZk/4wwLvacsGvoV4CjOa+tnWn2WPQW32rJWR4vZTPJNE2Il6T0ZmaCtAK/3qkoT5zhoEgAVgOHCYy9Y2PxOybO+RNONhZ/Ndx+Bycbiw/TgrHYOXkNtUoZw6sPL2hh4YShdGYm2M+i7EhpH0l7b/K46CU0OcUEsBV4gsDrZ1cYdw1wcbj9euBrBG6i3wXOjly7K7zuAPDRPM/TFBOeklR4F51PyIdCpqwM07ZgfNF3OjOj2pXHNKoI2g5SBB6SVnAWrSX7UJiWlcEXJdZ0ayrpnaqF4DVJikBzDYl6SHO3hGKumFmunW1QRYbJOXl2725vtG1bq7clve9JhuRu2wM015BoljR3y6KGVx/curJkSDN+T87J0+aUC0nG6ssvr9egG/dO45BPdj+Iayb4HtQ15CF5Rurm7a7wpXuliKEUnBuNuu8KyeqyqTMdo+9Uo497AbIRiEapu/D2wWuoqKHUh37xPAvINFE4+6C8RSZSBKJ5fCi82yCr1t1lLTjvAjJNPXsI77/HJCkC2QhEfXTZN94mWbaKLvvFowPlkpiZSb9H0cF/4/OvuCLYv/XW6X7/U4gUgRBFyTKUdj1XyVghJ3H8ePKxolNkTNOUGh7R9kB8KQIhspj8KiGodY9Gp57b5lwlWaVFUqsgrbVQdIqMaZpSwxM60a1x/UW+B9kIRCPE9XFnGUG76hfPu7BOUQNu2uC/uP86TVNqeEKT4ynRgDIhUkgaiPWGN/i5QHve9RKKDm5Luu9oBC+/3J/06TFNjqfUgDIxDMp2riZ1ccQVctD9QKm86yUUNeAnDf6D+PSJHo+er6l8S9PFeEopAjE9VOlcLVqwd20Qbqq0SJqe+9ix+POPHdMC9jXTyTT5cf1FvgfZCEQsVTpXk64djaoNlCpjQ8hzTdsDuHyYCHDKib720SgIdZue0IAyMfVUMVxmTSlRxiBcprBeWDj1fyRd06ahWiOHG6Wt5JUiENNP1Vpr3QVrmXUYkpSZDzVvjRxujLYaXEmKQF5Dwn/yer60NQVzXoq6f6RN7dzmFNyiddqaeb0RryEzO83M7jKzg+Hv+phzzjOz75jZo2b2sJn9fuTYLWb2fTN7MAznVZFHTCFFDMBZ6xC3TVGDbprBumvjtGiUrmder+o1tBO4xzl3DnBPuD/JS8AfOOfeAVwEfM7M3hQ5/ufOufPC8GBFecS0UXTkqk/zHdW1DoOZ3DGnnE48hSJUVQTbgL3h9l7gkskTnHNPOOcOhtv/AjwPbKz4XNF38vr75/WX94Xo/9q1C+bn87dQ4koDM7jqqmA7Kb3anphG1E7njdk4w0HeALwY2bbofsL57wIeB1aF+7cQLFr/MHAdsDbl2h3AMrC8efPmei0o
ol2KuEj0yW2xDtePotNcyJtnKmjLDk9ZryHgbuCRmLBtsuAHfpxyn01hof/uiTgD1hK0KP4iSx4nr6H+U6Rw71NB15TSSrtvnxSliKXN6axKK4K0EBbsm1ykoE8475eB/wtcmnKv9wP/lOe5UgQ9p6i/f1/cFpuagC3tvpr0rfek6fK660FJiqCqjWAfMB9uzwPfmDzBzNYAXwf+3jl3+8SxTeGvEdgXHqkoj+gDRVwkik6a1iVNuX6k3bdrdxNRmTQzWFuzfFdVBNcCHzKzg8CF4T5mNmdmN4Xn/B7wPuATMW6iS2a2H9gPbAD+sqI8og/kdZGoY2L2Ng2pTbl+pN23SFrKoOwlabq8NV+JuGaC70FdQ1NAnu6eOkYKt21faKobK+2+Wc/sk51lgKS9nrpNQGhksegdScMtITk+St45+6cdpYP3JPWA1j1YXusRxKDWsuekDbCqMrW0r2MQmkLp4D1J4yDH4wuiq6K+4Q31P3+wikBrbveA3buDQn8S5/JZy2RIDVA69J6XXz6xffRo/WXVYBWB1tzuAdu3J3cB5anNdj1uvwnKNGOnMR0GRBtl1WAVgVrLPWF2Nj4+T22283H7NVO2GTtt6TAw2iirBmsslv2sJ/g2tXSXKNMOkjpfu4zFE6i13BNUmz2BmrGDpI2yarCKQOVLj/BpaukukdF3kLRRVg1WEUBy+SK3UuElasZOHXnLmqbrQqvrvV3/meySHtvjYLgVUeEJ4wzYl7mXRCo+lTWDNRYnIXucEKINuihrZCzOiexxQog28KmskSKYQPY4IUQb+FTWSBFMIHucEKINfCprpAgmkFupEKINfCpr5DUUw/btKviFEM3g46J7lVoEZnaamd1lZgfD3/UJ5x2PrE62LxJ/lpndb2aHzOy2cFlLIYToBUXHHPk663HVrqGdwD3OuXOAe8L9OF52zp0Xhosj8X8NXOec+3fAj4ErK8ojhBCtUKZQ93XW46qKYBuwN9zeS7AAfS7CBes/CIwXtC90vRBCdEmZQj3LZbSrWQ2qKoI3O+eeC7d/CLw54bzXm9mymd1nZpeEcSPgRefcq+H+EeD0pAeZ2Y7wHssrKysVxRZCiGqUGQeQ5jLaZbdRpiIws7vN7JGYsC16XrgwctIw5dlwNNvHgc+Z2VuLCuqc2+Ocm3POzW3cuLHo5UIIUStlxgGkuYx22W2UqQiccxc65/5DTPgG8CMz2wQQ/j6fcI9nw9/DwL3A+cBR4E1mNvZcOgN4tvI/EkKIFigzDiDNZbTLkcZVu4b2AfPh9jzwjckTzGy9ma0NtzcAvw48FrYgvg1cmna9EEL4SN5xAJP9/hA/k2iXI42rKoJrgQ+Z2UHgwnAfM5szs5vCc94OLJvZQwQF/7XOucfCY58BPm1mhwhsBn9XUR4hhGiEOENu1vTQRfr9Ox1p7JzrXbjgggtcURYXnZuddc4s+F1cLHyLXPep6zlCiO4Zf88QfNNBcR6Edeuyv+/xtZNhdjb9eU2VH8CyiylTOy/Uy4SiimBxMXhp0Rdh5tzCQqHbxN4nmhmyjgsh+kPc95y3QB8zqTyi5U8XJCmCQaxHkDTvtxncemv+4d1Z84drLQMhpoek7zmKWdAtVPQeXZUJg16PIMnq7lwx16wsq75P84sLIaqR57vNMuQm9ftv3erXcriDUARpL6tIIZ1l1fdpfnEhRDWyvts8htw4z6L5edi716/5hgahCHbvDl5CHEUK6Syrvk/ziwshqhH3PY/LkSJTRk96Ft1xh4fzDcUZDnwPZbyGFhbSrf4LC87NzATxMzPJhmR5DQkxHJr4nrs0IDNkr6ExSS91YSH+xRT1KqoDKRIhpoOkb7moS2mdJCmCQXgNZbF6NRw/fmr8zAy8+uqp8U0xHnwSbTauW6cV0oToG2nfMnT3nQ/aayiLOCWQFt8Uvs5VLoTIJjryeH4++Vv2aYnKMWoR4E+LYNWqoJE4SZavshCiW+JaAHF0/S2rRZDCjh3F
4ptC7qdC9JO41nwcvn7LUgTADTfAwkLQAoDgd2EhiB/TxspBcj8VohpdrfCVZzxS2rfcldy/IM6C7Hso6zVUljxzDNXl6SOvISHK0eVcX0meQFGPoCQ52pQbeQ2VJ22+kN275ekjhA90Oa/P0hJcfnn8MZ/mI0qyEUgR5CDNiLt5c/xLHI3ghReal00IEZD0nULwPR47Fnyvu3c3U0nbsAGOHj01PqtAb9NJpBFjsZmdZmZ3mdnB8Hd9zDkfMLMHI+FfxwvYm9ktZvb9yLHzqsjTFGlG3KS+waNHu59ISoghkWaIPXq0+Xl9rr++nI3PByeRqsbincA9zrlzgHvC/ZNwzn3bOXeec+484IPAS8D/iJzy5+PjzrkHK8rTCGlG3LSXJf9/Idoj7juNo6mxOWXHB3jhJBJnOMgbgAPApnB7E3Ag4/wdwFJk/xbg0qLPbdtY7FyyEXdxMdlAFDd3SPQ+o1EQZBgWoh7SvkcfFoZJoi0nEZqYawh4MbJt0f2E878F/KfI/i2hMnkYuA5Ym+e5XSiCNEajZE+BKFkrHjXp4SBvJDEUsjx4xhNLDvEbKK0IgLuBR2LCtsmCH/hxyn02ASvA6ybiDFgL7AX+IuX6HcAysLx58+YWkiw/ed2/8mTQJiae0hKaYkjkWWJyqN9AUy2C3F1DwH8F9qQcfz/wT3me61uLwLl8Ne6k6WebbrJ2OduhEF0w2QW7apW+AeeSFUFVY/E+YD7cnge+kXLux4AvRyPMbFP4a8AlBC2NXjK5+EScgSiPF0ATngJaQlP0gTpH10a/xxdeSHYr1TcQUFURXAt8yMwOAheG+5jZnJndND7JzLYAZwL/c+L6JTPbD+wHNgB/WVEer8nyamjKU8AH9zQh0hhP2tbU8o36BjKIayb4HnzsGspLF15DshEI36mj+zKte9b3b6DXXkNdhT4rgq6Q15DwmarLN+Yp6BcXT/bwG438+A4011BJ2p5iQgjRLFXn28lzva8rAJadmqIMWo9ACOEtVUfX5nGI8HEFwKWleCUA7RqypQiEEJ1TdfnGNGPw2BsprsUA3XoOpSmhPs01JGqi84UphKiJsnk5jwt2Ekktiq1bT3gjJdGl51CaEmpzriEpAg9o2nWuTaTQppO877WrvJzUorjjjvQlJLteATBJCY1GLdst4izIvodp8xqalpG/vrvoiXIUea++5eW00fw+eM+1/c0g91F/qeo65wu+FQKiHoq8V1/y8thduok5vep2xW7TtTtJEahrqEPGzW2X4MHbt1GPmspiOinyXn0YwRvtnoqjSndQE11fVWwjdSFF0BFNZtbJ57TVZ+9DISDqp8h77XqRlaUlmJ9PtgsU9UaaxEcX1FqIayb4Hqahayir2VpH87Dt/kfZCKaTou81b1dHE10sadNP19E95UvXV1mQjcAv2shQXfTZayqL6aSNQrtqpSFrvY868n3f7WBSBJ7RRobqe+1F1I8virqJ/J/mIVRXy7Tvrd4kRSAbQUe00ZeqPnu/aXvMhU/jVZpwLEjK1zMz9c0nVHUEtLfEaQffwzS0CJxrvnbW99rLNNPFu6lSC687rzbRIlB+zwZ1DU03SR9qnR9wW90KvnRfNEkXfc1luwqbKGCbKrSHkHeq0IgiAH4XeBR4DZhLOe8igvWNDwE7I/FnAfeH8bcBa/I8V4rgZNqoCaU9o25lM4RaXRf2m7LKpymlpUK7fZpSBG8HfhW4N0kRADPAk8DZwBrgIeDc8NhXgcvC7RuBhTzPlSI4mTZql0nPGI3qLbj77pWRl648usq8KzkdTA9JiqCSsdg597hz7kDGae8CDjnnDjvnfgZ8BdgWLlj/QeD28Ly9BAvYi4K0MaI36V5Hj9Y7wGYoo5O7GHhV1tApp4Pppw2vodOBZyL7R8K4EfCic+7VifhYzGyHmS2b2fLKykpjwvaRNj7UovcqW3APpdDpyvukzHQGXY8WFs2TqQjM7G4zeyQmbGtDwDHOuT3OuTnn
3NzGjRvbfLT3tPGhJj1jNIo/v2zBPaRCx4c5ZvIwtS6T4heszjrBOXdhxWc8C5wZ2T8jjDsKvMnMVoetgnG8KMj4g9y1K6iJb94cFJx1fqhJz4D4dWDLFtxt/BdRnO3b9Q6mmVoWrzeze4E/c86dsqK8ma0GngB+k6Cg/x7wcefco2b2NeAfnHNfMbMbgYedczdkPU+L1/vF0pIKbiH6QCOL15vZb5vZEeA9wD+b2Z1h/FvM7A6AsLb/KeBO4HHgq865R8NbfAb4tJkdIrAZ/F0VeUQ39KWLQwgRTy0tgrZRi0AIIYrTSItACCFE/5EiEEKIgSNFIIQQA0eKQAghBk4vjcVmtgIkrPabygbghZrFqQPJVRxfZZNcxfFVtmmUa9Y5d8qI3F4qgrKY2XKcxbxrJFdxfJVNchXHV9mGJJe6hoQQYuBIEQghxMAZmiLY07UACUiu4vgqm+Qqjq+yDUauQdkIhBBCnMrQWgRCCCEmkCIQQoiBM3WKwMx+18weNbPXzCzRxcrMLjKzA2Z2yMx2RuLPMrP7w/jbzGxNTXKdZmZ3mdnB8Hd9zDkfMLMHI+FfzeyS8NgtZvb9yLHz2pIrPO945Nn7IvGNpFde2czsPDP7TvjOHzaz348cqzXNkvJM5PjaMA0OhWmyJXLs6jD+gJl9pIocJeT6tJk9FqbPPWY2GzkW+15bkusTZrYSef4fRY7Nh+/9oJnN1ylXTtmui8j1hJm9GDnWSJqZ2c1m9ryZPZJw3Mzsb0OZHzazd0aOVUuvuIWM+xyAtwO/CtwLzCWcMwM8CZwNrAEeAs4Nj30VuCzcvhFYqEmuvwF2hts7gb/OOP804BiwLty/Bbi0gfTKJRfw04T4RtIrr2zAvwfOCbffAjwHvKnuNEvLM5Fz/hi4Mdy+DLgt3D43PH8tcFZ4n5kW5fpAJB8tjOVKe68tyfUJ4PMx154GHA5/14fb69uUbeL8PwVubiHN3ge8E3gk4fhW4JuAAe8G7q8rvaauReCce9w5dyDjtHcBh5xzh51zPwO+AmwzMwM+CNwenrcXuKQm0baF98t730uBbzrnXso4rypF5foFDadXLtmcc0845w6G2/8CPA80sZZpbJ5Jkfd24DfDNNoGfMU594pz7vvAofB+rcjlnPt2JB/dR7AaYNPkSa8kPgLc5Zw75pz7MXAXcFGHsn0M+HKNz4/FOfe/CCp/SWwD/t4F3EewwuMmakivqVMEOTkdeCayfySMGwEvumAxnWh8HbzZOfdcuP1D4M0Z51/GqZlvd9gkvM7M1rYs1+vNbNnM7ht3V9FsehWRDQAzexdBDe/JSHRdaZaUZ2LPCdPkJwRplOfaJuWKciVBrXJM3HttU67fCd/P7WY2XtK2yfQqdP+wG+0s4FuR6KbSLIskuSunV+aaxT5iZncD/zbm0C7n3DfalmdMmlzRHeecM7NEv91Qy/9HglXdxlxNUBiuIfAj/gxwTYtyzTrnnjWzs4Fvmdl+goKuEjWn2a3AvHPutTC6dJpNI2Z2OTAH/EYk+pT36px7Mv4OtfOPwJedc6+Y2X8maE19sKVn5+Uy4Hbn3PFIXJdp1gi9VATOuQsr3uJZ4MzI/hlh3FGC5tbqsEY3jq8sl5n9yMw2OeeeCwut51Nu9XvA151zP4/ce1wzfsXMvgT8WZtyOeeeDX8PW7BG9fnAP1AhveqSzcx+GfhngorAfZF7l06zGJLyTNw5RyxYq/tXCPJUnmublAszu5BAuf6Gc+6VcXzCe62jUMuUyzl3NLJ7E4FNaHzt+yeuvbcGmXLLFuEy4E+iEQ2mWRZJcldOr6F2DX0POMcCj5c1BC97nwssL98m6J8HmAfqamHsC++X576n9EmGBeG4X/4SINazoAm5zGz9uFvFzDYAvw481nB65ZVtDfB1gr7T2yeO1ZlmsXkmRd5LgW+FabQPuMwCr6KzgHOA71aQpZBcZnY+8EXgYufc85H42PfaolybIrsXE6xpDkFL+MOh
fOuBD3Ny67hx2UL53kZgfP1OJK7JNMtiH/AHoffQu4GfhJWd6unVhPW7ywD8NkEf2SvAj4A7w/i3AHdEztsKPEGgyXdF4s8m+EgPAV8D1tYk1wi4BzgI3A2cFsbPATdFzttCoOFXTVz/LWA/QWG2CLyxLbmA94bPfij8vbLp9Cog2+XAz4EHI+G8JtIsLs8QdDVdHG6/PkyDQ2GanB25dld43QHgozXn+Sy57g6/hXH67Mt6ry3J9VfAo+Hzvw28LXLtH4bpeAj4ZJ1y5ZEt3P8scO3EdY2lGUHl77kwPx8hsOdcBVwVHjfgC6HM+4l4RVZNL00xIYQQA2eoXUNCCCFCpAiEEGLgSBEIIcTAkSIQQoiBI0UghBADR4pACCEGjhSBEEIMnP8PNRUZ2fYpuekAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Visualization of the 100 test data points:\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAh60lEQVR4nO3de8wd9X3n8ffHpia1UBvbeAkB/Bi2pAltuiQ86yQbaZsLCYSVMNmSxKmTddJEXtKQlTbaFUZIm4hda2n3D9SqUbsWSXDwswFKFeFuyLJct380pDxI3AwCGwgXlwTHTiJZUK7f/WPmxOPjc33O3OfzkkbnnJk55/zOzJz5/uZ3G0UEZmbWXcuqToCZmVXLgcDMrOMcCMzMOs6BwMys4xwIzMw67riqE7AUJ554Yqxfv77qZJiZNcp99933s4hY2z+/kYFg/fr1LC4uVp0MM7NGkfT0oPkuGjIz6zgHAjOzjnMgMDPrOAcCM7OOcyAwM+u4XAKBpG9JekHSw0OWS9KfS9on6UFJ784s2yJpbzptySM9ZjabhQVYvx6WLUseFxaqTpEVKa8rgmuB80cs/xhwZjptBf4SQNJq4GvAe4ANwNckrcopTWa2BAsLsHUrPP00RCSPW7c6GLRZLoEgIv4OODRilY3AdyJxD/BmSScD5wG3RcShiPg5cBujA0rtOSdlTXfFFfDii0fPe/HFZL61U1kdyk4Bns28fi6dN2z+MSRtJbmaYN26dcWkcka9nFTvT9TLSQFs3lxdusym8cwz08235mtMZXFE7IiI+YiYX7v2mB7StdCUnJSvWqpV9+0/LJ9V0/zXzKrYH3U7Bsq6ItgPnJZ5fWo6bz/wgb75d5eUptw1ISflq5Zq1X37LyzA4cPHzl+5ErZvLz89Ratif9TyGIiIXCZgPfDwkGX/BvgBIOC9wD+k81cDTwGr0ukpYPW47zrnnHOijubmIpLqtaOn5csjdu2qOnWJYWmcm8v3e3btSj5TSh7r8vurVtb2X4pduyJWrjw2bWvWtHf/TbI/8jqWe58z6PvKOgaAxRh0jh40c9oJ+C7wPPAqSTn/F4BLgEvS5QK+ATwBPATMZ977R8C+dPr8JN+XRyAo4kQ17I8Eyfw6/JmkwemT8vuOQduhLr+/amVs/6Wqc5Ca1bD/+7j9kdexPOrcUOYxUGggKHuaNRAUeaLatSu5AqjrH6qMP3ubTyizqvO2qXOQmsWo//u4/ZHX/hp1JdCaK4Kyp1kDQdF/xjr/ocrIrdf591etzldLdQ5Ssxj1u8btj7yO5WGfU/Yx4ECQUfSJqu5/qKLL7+v++6tW1/qTooNUVb97kuKfYekq44qgzG3hQJBR9Imqzrm+MnT99zdZUSfrKo+JWf7vRdYRVPGfcCDIKGOn1DXXV5au/347WpVXibP+3/NuNVTlf2JYIFCyrFnm5+dj1ltVLiwkHb2eeSbpKHPBBXDLLUdeb99ej3bdZm2wbFlyCu4nwRtvFP/9/f/3rv6/Jd0XEfP98xvTszhvmzfDj3+cHITbt8POnfkOslVGz8G69U40G6bq3srZ//uPf9zNIDBKZwNBVt5DQ5QxeqNHiLQm2b496Z2cNai3sjM3FRlUXlT3Ke+exXm3Ilpqeeg0ZYhumWNNM+74rkuFapvhOoLh1q9PctT95uaSy8hpLaU8tH/8EUhyTDt2DL6MrbrM1Sxvef8P7ViuIxhh0svWSS2lPHTa4qmqy1zN8taEQRvbyoGAJMe9Y0eS85CSx2E58UksJbBM+yfIO3iZVc2Zm+o4EKTybFWwlMAy7Z8g7+BlVrWiMjeugJ7AoIqDuk91HYZ6Fq4oM8u/05X/V0djSGWxrwhqwjl8s/zb+1d518AmXYm41ZCZtVZVreumbQVYFrcaMrPOqaoCuin3L+/JJRBIOl/SY5L2Sdo2YPnVku5Pp8cl/SKz7PXMst15pMfMDKprXde0prAz37xe0nKS21B+
hOQ2lfdK2h0Rj/TWiYj/mFn/K8C7Mh/xUkScPWs6zMz69Yphyh5wbt26wZ3j6toUNo8rgg3Avoh4MiJeAa4HNo5Y/9Mk9ziutSIreppUiWTWdFUMONe0fj55BIJTgGczr59L5x1D0hxwOnBnZvabJC1KukfSRcO+RNLWdL3FAwcO5JDs4Yoc0M2DxZm1X9NaAc7cakjSxcD5EfHF9PVngfdExKUD1r0MODUivpKZd0pE7Jd0BkmA+HBEPDHqO4tuNVTkmCceT8XMqlJkq6H9wGmZ16em8wbZRF+xUETsTx+fBO7m6PqDUvWKbAadqCGfip6mVSKZWfvlEQjuBc6UdLqkFSQn+2Na/0h6O7AK+GFm3ipJx6fPTwTeDzzS/94yZItshsmjosfjqZhZ3cwcCCLiNeBS4FbgUeDGiNgj6UpJF2ZW3QRcH0eXRb0DWJT0AHAXcFW2tVGZBrX7zcqroqdplUhmVp2yGpbM3HwUICJuAW7pm/df+l5/fcD7/h54Zx5pmNWoopm5uXybnP36rx8JOmvWwJ/9WX0rkcysGv29k3sNSyD/84V7FqeGFc30KnHz2PC9HXvw4JF5L700++eaWfuU2TvZgSBVRpFN07qdm1l1ymxY4kCQKqPdr1sMmdmkymxY4kCQUXQPRLcYMrNJldmwxIGgRG4xZGaTKrN3ci6thmwyVQ2AZWbNtHlzOecHXxGUrIoBsMysHE0dUNJXBGZmOSiz3X/efEVgZpaDJjcPdyAwM8tBk5uHOxCYmeVg9erp5teJA4GZWcc5EJiZ5eDQoenm14kDgZlZDiYdOaCOTUwdCMzMcjDJyAF1vWe5A8GU6hjNzax6kwwJUdcmprkEAknnS3pM0j5J2wYs/5ykA5LuT6cvZpZtkbQ3nbbkkZ6i1DWam1k9jBs5oK5NTGcOBJKWA98APgacBXxa0lkDVr0hIs5Op2vS964Gvga8B9gAfE3SqlnTVJS6RnMza4a6jkCcxxXBBmBfRDwZEa8A1wMbJ3zvecBtEXEoIn4O3Aacn0OaClHXaG5mzVDXEYjzCASnAM9mXj+Xzuv3B5IelHSTpNOmfC+StkpalLR44MCBHJI9vbpGczNrhjKHlp5GWZXFfwusj4jfI8n175z2AyJiR0TMR8T82rVrc0/gJOoazc2sOeo4AnEegWA/cFrm9anpvF+JiIMR8XL68hrgnEnfWyd1jeZmZrPIIxDcC5wp6XRJK4BNwO7sCpJOzry8EHg0fX4r8FFJq9JK4o+m82qrjtHcasjtjK1BZr4fQUS8JulSkhP4cuBbEbFH0pXAYkTsBv6DpAuB14BDwOfS9x6S9F9JggnAlRHRgA7ZZiM0eWB66yRFRNVpmNr8/HwsLi5WnQyzwdavT07+/ebmkstIs4pIui8i5vvnu2exWd7cztgaxoHALG9uZ2wN40Bglje3M7aGcSAwy5vbGVvDOBCYFWFUO2M3LbWambn5qJlNwU1LrYZ8RWA2Tp45eA9hazXkQGA2St43oRjXtNTFRlYBBwKzUfLOwY9qWuo7H1lFHAismcrKOefdOWxU01IXG1lFHAisecrMOefdOWxU01L3SLaKOBBY85SZcy6ic9iwpqXukWwVcSCw5ikz51xE57BhxVrukWwVcT8Ca5516waP7llUznnz5vza+E/Sj+CKK5Kgtm5dEgTcv8AK5isCa4ZsLvrwYVix4ujlTck5jyvW8p2PLKOsNhEOBFZ//ZXDBw8mj2vWHCmu2bIlOZnWvf29K4RtQmW2icglEEg6X9JjkvZJ2jZg+VclPSLpQUl3SJrLLHtd0v3ptLv/vWYDc9GvvgonnJDknLdvh507m9H+vogKYXdCa6VSWxNHxEwTye0pnwDOAFYADwBn9a3zQWBl+vxLwA2ZZYen/c5zzjknrEOkiOQUf/QkJcvn5gYvn5urMtWD7doVsXLl0elcuTKZP+3n9H53//ZZyudZ7Yw77JeC5PbBx5xT87gi2ADsi4gnI+IV
4HpgY1+wuSsierHtHuDUHL43V85U1di4XHSTilvyaIWULTOA5PyQ5U5orVBma+I8AsEpwLOZ18+l84b5AvCDzOs3SVqUdI+ki4a9SdLWdL3FAwcOzJTgfu7ZX3PjmlU2rf39rBXCg8oM+k0bBJ0Tqp1SWxMPukyYZgIuBq7JvP4s8BdD1v0MyRXB8Zl5p6SPZwA/Bv75uO/Mu2ioSSULndUrCpGSx2zRR17FLWWkNQ/DygyWevAO2n4QsWaNi5gqlvehxJCioTwCwfuAWzOvLwcuH7DeucCjwD8b8VnXAheP+868A0ERZXFWsqJPvtOko+igNCznstTvG/V5rm+oXJ6HdpGB4DjgSeB0jlQW/07fOu8iqVA+s2/+qt7VAXAisJe+iuZBk68IrLbKOJgGBZtebmYpZ4pxVxj+I1Qm73zFsEAwcx1BRLwGXArcmub4b4yIPZKulHRhutr/AE4A/rqvmeg7gEVJDwB3AVdFxCOzpmla7tlvuSmj4npQhfN11yXniaXUOYyrS6ljpXtHlNWEVEmQaJb5+flYXFzM9TMXFtyz33Kwfv3g4S/m5pKTdB31D3vRr85pb7lly45tFAZJ/H/jjek/T9J9ETF/zPcsJXFt5J79losmXl72rjDWrDl2Wd3T3nJlNYhzIDDryaMJZRGjlZZh82b42c9g167mpb3FyspXuGjIDAYXj6xc6ROhVS7PYuthRUMOBGbQzLJ9sym5jsBslCYNU2GWMwcC67ZevcCwK+O6DlPRz0NE2Ax8hzLrrnHNJpvSYmaSu56ZjeArAuuuUYO3NanFTKkD11sb+YrAumtY+b/UrApi12/YjHxFYN3VtOGrh2nL77DKOBBYdzWxF/Agk/wOVybbCA4E1l159wKu6mQ77nf4zks2hjuUmeWhzj2T3VnOUu5QZlakOrfccWWyjeFAYJaHOp9sXZlsYzgQmOWhzifbKirFXTndKLkEAknnS3pM0j5J2wYsP17SDenyH0lan1l2eTr/MUnn5ZEes9LVuQVS2UNju3K6cWauLJa0HHgc+AjwHHAv8OnsLScl/THwexFxiaRNwMcj4lOSzgK+C2wA3grcDrwtIl4f9Z2uLLZa8m3uEq6crq0iK4s3APsi4smIeAW4HtjYt85GYGf6/Cbgw5KUzr8+Il6OiKeAfennmTWPb3OXqHN9iQ2URyA4BXg28/q5dN7AddKb3f8SWDPhewGQtFXSoqTFAwcO5JBsswZoYll7netLbKDGVBZHxI6ImI+I+bVr11adHLPiNbWsvc71JTZQHoFgP3Ba5vWp6byB60g6DvhN4OCE77Wua2KuOA917pswSlPv29xheVQWH0dSWfxhkpP4vcAfRsSezDpfBt6ZqSz+txHxSUm/A/wvjlQW3wGc6cpi+5U699gt2rJlg2+YIyX1EGZTKqyyOC3zvxS4FXgUuDEi9ki6UtKF6WrfBNZI2gd8FdiWvncPcCPwCPB/gC+PCwLWMU3NFefBZe1WEo81ZPXW5Vxxl6+GrBAea8iaqcu5Ype1W0kcCKzeBrVA+bVfg8OHu1F57L4JVgIHAqu3/lzxmjXJ48GDzWpSaTaBqhrIORBY/WVzxSecAK+8cvTyulQed7WZq+Wiym4jDgTWLHUbvqB38pfgs59tXuevaTjQFarKBnIOBNYsdao8zmbh4NjWTXW5UslDU3s5N8igcfpGzc+TAwHO6DRKnYYvGJSF69eWgda63J+jJMuXTzc/T50PBM7oNMywJpVQfjSf5CTflmaudSuSa6HXh3SlHTY/T50PBM7oNFB/k0qoJpqPO8m3aaC1OhXJtdTc3HTz89T5QOCMTgtUFc0HFVNJyWPbOn/VqUiupS644Mjh01PWJu58IHBGpwWqiuaDiqmuuy65Kmlb5y/3ci7UwgLs3Hl0ewMJtmwpZxN3fqwhD+fSAr41ojVcWYewxxoawhmdFnCxhTVc1UXUnQ8E4OFcGs/R3Bqu6iJqBwJrB0dza7CqL2pn
CgSSVku6TdLe9HHVgHXOlvRDSXskPSjpU5ll10p6StL96XT2LOkxK5V7IlpOqr6onamyWNKfAoci4ipJ24BVEXFZ3zpvAyIi9kp6K3Af8I6I+IWka4H/HRE3TfO9vjGNVc6tDKyBiqos3gjsTJ/vBC7qXyEiHo+IvenzfwReANbO+L1m1XJPRGuRWQPBSRHxfPr8J8BJo1aWtAFYATyRmb09LTK6WtLxM6bHrBxVN/Mwy9HYQCDpdkkPD5g2ZteLpIxpaDmTpJOB64DPR0TvZrOXA28H/iWwGrhsyNuRtFXSoqTFAwcOjP9lZkWqupmHWY7GBoKIODcifnfAdDPw0/QE3zvRvzDoMyT9BvB94IqIuCfz2c9H4mXg28CGEenYERHzETG/dq1LlqxiVTfzMMvRrEVDu4Et6fMtwM39K0haAXwP+E5/pXAmiIikfuHhGdNjVo6qm3lYq5XdIG3WVkNrgBuBdcDTwCcj4pCkeeCSiPiipM+Q5Pb3ZN76uYi4X9KdJBXHAu5P33N43Pe61ZCZtVWRDdKGtRrq/FhD4ywsJA1BnnkmKf7dvt2ZPjMrTpHjDnmsoSXwTWtawh2/rEGqaJDmQDCCm4q3gKO5NUwVDdIcCEZwU/EWcDS3hqmiQZoDwQhuKl5zkxT5tCmau4irE6pokOZAMIKbitfYpEU+bYnmLuLqlLIH03UgGMFNxWts0iKftkRzF3FZgdx81Jpp2bKjb/DaIyXZqKw2tAGe5veaDTGs+ehxVSTGbGbr1g1ubD2oyGfz5uad+PtN83vNpuSiIWumthT5TKprv9dK5UBgzdS1Cpyu/V4rlesIctCGImgzaz8PMdEnrybZbtVnZk3XyUCQ58nbrfrMrOk6GQjyPHm3qeOqmXVTJwNBnifvtnRcNbPu6mQgyPPk7VZ9ZtZ0MwUCSasl3SZpb/q4ash6r0u6P512Z+afLulHkvZJuiG9rWXh8jx5u1WfmTXdrFcE24A7IuJM4I709SAvRcTZ6XRhZv6fAFdHxG8BPwe+MGN6JpL3ybvsAaLMzPI06z2LHwM+EBHPpzeivzsifnvAeocj4oS+eQIOAG+JiNckvQ/4ekScN+5769aPwMysCYrqR3BSRDyfPv8JcNKQ9d4kaVHSPZIuSuetAX4REa+lr58DTpkxPWZmNqWxg85Juh14y4BFRzW2jIiQNOzyYi4i9ks6A7hT0kPAL6dJqKStwFaAdW6SY2aWm7GBICLOHbZM0k8lnZwpGnphyGfsTx+flHQ38C7gb4A3SzouvSo4Fdg/Ih07gB2QFA2NS7eZmU1m1qKh3cCW9PkW4Ob+FSStknR8+vxE4P3AI5FUTtwFXDzq/WZmVqxZA8FVwEck7QXOTV8jaV7SNek67wAWJT1AcuK/KiIeSZddBnxV0j6SOoNvzpgeMzOb0kyBICIORsSHI+LMiDg3Ig6l8xcj4ovp87+PiHdGxL9IH7+Zef+TEbEhIn4rIj4RES/P9nOqMckAdr7vuJnVle9QNqPeAHa9sYt6A9jBkf4Ek6xjZlaVzgwxUVSOfJIB7DxCqZnVWSeuCIrMkU8ygJ1HKDWzOuvEFUGROfJJBrDzCKVm7dWG+r9OBIIic+TjBrBbWIDDh499n0coNWu+ttyhsBOBoMgc+agB7HoHycGDR79nzRqPUGrWBm2p/+vEzev76wggyZEXfTJevz7JIfSbm0tGKTWzZlu2LLkS6CcloxHXTadvXl/VPQNcSWzWbm2p/+tEIIBq7hnQloPEzAab9iZXS61YLrpCujOBoAq+jaVZu01T2rDUiuUyKqQ7UUdQpYWFpOLomWeSK4Ht211JbNZFS60zzLOucVgdgQOBmVkJllqxnGeFdKcri9ukDZ1XzLpoqXWGZdQ1OhA0SFs6r5h10VLrDMuoa3QgaJC2dF4x66KlNmMvo/m76wgapGmdV8ysXgqpI5C0WtJtkvamj6sGrPNBSfdnpn+SdFG67FpJT2WWnT1LetrO/RLM
rAizFg1tA+6IiDOBO9LXR4mIuyLi7Ig4G/gQ8CLwfzOr/Ofe8oi4f8b0tJr7JZhZEWYNBBuBnenzncBFY9a/GPhBRLw4Zj0boKqhMsys3WYNBCdFxPPp858AJ41ZfxPw3b552yU9KOlqScfPmJ7Wq2KoDDNrt7F3KJN0O/CWAYuOaqsSESFpaM2zpJOBdwK3ZmZfThJAVgA7gMuAK4e8fyuwFWCdC8XNzHIz9oogIs6NiN8dMN0M/DQ9wfdO9C+M+KhPAt+LiFczn/18JF4Gvg1sGJGOHRExHxHza9eunfT3mZktWVc6cM5aNLQb2JI+3wLcPGLdT9NXLJQJIiKpX3h4xvSYmU1s1Im+Sx04Zw0EVwEfkbQXODd9jaR5Sdf0VpK0HjgN+H9971+Q9BDwEHAi8N9mTI+N0JXcjdkkxp3ou9SB0x3KxuiNHvr007B8Obz+etJap2mjiFZ1lzazuho3qmcbO3B60LklyOYYIAkC0MxLxC7lbswmMeoOggsLSSAYpKi2KlVesfuKYIRhOYaeJt17uI25G7NZDPt/r1kDL710bMYJiruKLuuK3VcESzDu3sJNuvewh6cwS/Ry3k8/nWSEsno99wcFgeXLiwsCW7ZUe8XuQDDCUscJryMPT2F2bHFvxJFg0Oupf+jQ4Pe+8cZ0QWCSop5eenrFzv1Ky2xGROOmc845J8qwa1fEypURyeFy9LRyZbK8SXbtipibi5CSx6al32xWc3OD/89zc9OtM86gc8egc8aw71rKd04CWIwB59TKT+pLmcoKBBFHTp4QsXz5kZ1T1UnUJ3OzpZMGn3ClI+tMehIfZdJgMiw9RWU2HQhaII8DtGoOZFalSU/Qsx6nkwScUelZvryY/4YDQQvkccmaVfZJuQ2BzJqtrGNwmoBT5n/CgaAFJs1lTKKKk3LegcxsKcrIAE3z/yozQzYsELgfQYOM6wlZ1WdNyn0ZrEt6oxI880zSwrAOoxG4H0EL5NkEdFSvyqK4L4N1SZPuHeJA0CB53qGsipPyoEAmwQUXFPedZjaeA0HD5JXLGHd1UcS4J5s3Jz0os705I2DnzmaN29QGHonWjjKo4qDuU1cri/M2rJKqyIrkulQYd7kZaxtbb3V5f04DtxoqXlsOxiJP1nm2fFqqNp4Ip1GXYJyXru/PaQwLBC4aykmb7mZUZEVyHSqMuz4k97T7t+7FSG3cn6Vv80HRYdIJ+ASwB3gDmB+x3vnAY8A+YFtm/unAj9L5NwArJvneOl4RtCmXVeRvqUPurQ5XJVWaZv/WYX+N07b9WeQ2p4iiIeAdwG8Ddw8LBMBy4AngDGAF8ABwVrrsRmBT+vyvgC9N8r11DARtOhiL/vNXXYTWpqC9FNPs3yZsqyakcRpF/p5CAsGvPmR0IHgfcGvm9eXpJOBnwHGD1hs11TEQtO1grPpkXaQm5HKLNun+bUIGp237s8htPiwQlFFHcArwbOb1c+m8NcAvIuK1vvmN1Lbx/pvUGWZaefbHaKpJ928d6nTGadv+rGKbjw0Ekm6X9PCAaWNxyRqYjq2SFiUtHjhwoMyvnkjbDsa2a3Ogy1NTMjht2p9VbPPjxq0QEefO+B37gdMyr09N5x0E3izpuPSqoDd/WDp2ADsgGWtoxjQVYvPmZh+AZv16x3Pdxsxpsyq2+dhAkIN7gTMlnU5yot8E/GFSFqa7gIuB64EtwM0lpMfMpuAMTvnK3uYz1RFI+rik50gqer8v6dZ0/lsl3QKQ5vYvBW4FHgVujIg96UdcBnxV0j6SOoNvzpIeMzObnoehNjPrCA9DbWZmAzkQmJl1nAOBmVnHNbKOQNIBYMCNFidyIkmP5rpxuqZT13RBfdPmdE2vrmlbarrmImJt/8xGBoJZSFocVFlSNadrOnVNF9Q3bU7X9OqatrzT5aIhM7OOcyAwM+u4LgaCHVUnYAinazp1TRfUN21O1/TqmrZc09W5OgIzMztaF68IzMwsw4HAzKzjWhkIJH1C
0h5Jb0ga2sRK0vmSHpO0T9K2zPzTJf0onX+DpBU5pWu1pNsk7U0fVw1Y54OS7s9M/yTponTZtZKeyiw7u6x0peu9nvnu3Zn5VW6vsyX9MN3fD0r6VGZZrttr2PGSWX58+vv3pdtjfWbZ5en8xySdN0s6lpCur0p6JN0+d0iayywbuE9LTNvnJB3IpOGLmWVb0n2/V9KWktN1dSZNj0v6RWZZYdtM0rckvSDp4SHLJenP03Q/KOndmWVL316DblvW9ImK7qU8Qbr+FNiWPt8G/MmY9VcDh4CV6etrgYsL2F4TpQs4PGR+ZdsLeBtwZvr8rcDzwJvz3l6jjpfMOn8M/FX6fBNwQ/r8rHT944HT089ZXmK6Ppg5hr7US9eofVpi2j4H/MWA964GnkwfV6XPV5WVrr71vwJ8q6Rt9q+BdwMPD1l+AfADklv9vhf4UR7bq5VXBBHxaEQ8Nma1DcC+iHgyIl4huSfCRkkCPgTclK63E7gop6RtTD9v0s+9GPhBRLyY0/cPM226fqXq7RURj0fE3vT5PwIvAMf0nMzBwONlRHpvAj6cbp+NwPUR8XJEPAXsSz+vlHRFxF2ZY+gekptAlWGSbTbMecBtEXEoIn4O3AacX1G6Pg18N6fvHiki/o4k8zfMRuA7kbiH5OZeJzPj9mplIJhQFfdSPikink+f/wQ4acz6mzj2ANyeXhJeLen4ktP1JiW3C72nV1xFjbaXpA0kObwnMrPz2l7DjpeB66Tb45ck22eS9xaZrqwvkOQoewbt07xMmrY/SPfRTZJ6dzOsxTZLi9FOB+7MzC5ym40zLO0zba8y7lBWCEm3A28ZsOiKiKjsTmej0pV9EREhaWjb3TTKv5Pkhj49l5OcEFeQtCO+DLiyxHTNRcR+SWcAd0p6iORkt2Q5b6/rgC0R8UY6e8nbq40kfQaYB34/M/uYfRoRTwz+hEL8LfDdiHhZ0r8nuaL6UInfP84m4KaIeD0zr+ptlrvGBoKoyb2Up0mXpJ9KOjkink9PXC+M+KhPAt+LiFczn93LHb8s6dvAfyozXRGxP318UtLdwLuAv6Hi7SXpN4Dvk2QC7sl89pK31wDDjpdB6zwn6TjgN0mOp0neW2S6kHQuSXD9/Yh4uTd/yD7N66Q2Nm0RcTDz8hqSeqHeez/Q9967y0pXxibgy9kZBW+zcYalfabt1eWioV/dS1lJK5dNwO5Ial5691KGfO+lvDv9vEk+95hyyfRk2CuXvwgY2LKgiHRJWtUrWpF0IvB+4JGqt1e6775HUm56U9+yPLfXwONlRHovBu5Mt89uYJOSVkWnA2cC/zBDWqZKl6R3Af8TuDAiXsjMH7hPc0rXpGk7OfPyQpLb2UJyJfzRNI2rgI9y9NVxoelK0/Z2korXH2bmFb3NxtkN/Lu09dB7gV+mGZ7ZtldRtd9VTsDHScrIXgZ+Ctyazn8rcEtmvQuAx0mi+RWZ+WeQ/FH3AX8NHJ9TutYAdwB7gduB1en8eeCazHrrSSL8sr733wk8RHJC2wWcUFa6gH+VfvcD6eMX6rC9gM8ArwL3Z6azi9heg44XkqKmC9Pnb0p//750e5yRee8V6fseAz6W8/E+Ll23p/+D3vbZPW6flpi2/w7sSdNwF/D2zHv/KN2W+4DPl5mu9PXXgav63lfoNiPJ/D2fHtPPkdTpXAJcki4X8I003Q+RaRU5y/byEBNmZh3X5aIhMzPDgcDMrPMcCMzMOs6BwMys4xwIzMw6zoHAzKzjHAjMzDru/wPTby8hcT1iEgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " 读者不妨自己调节数据集的参数设置来生成属于自己的数据集吧！\n"
     ]
    }
   ],
   "source": [
    "# 数据集参数设置\n",
    "Ntrain = 200        # 规定训练集大小\n",
    "Ntest = 100         # 规定测试集大小\n",
    "boundary_gap = 0.5  # 设置决策边界的宽度\n",
    "seed_data = 2       # 固定随机种子\n",
    "\n",
    "# 生成自己的数据集\n",
    "train_x, train_y, test_x, test_y = circle_data_point_generator(Ntrain, Ntest, boundary_gap, seed_data)\n",
    "\n",
    "# 打印数据集的维度信息\n",
    "print(\"训练集 {} 个数据点的可视化：\".format(Ntrain))\n",
    "data_point_plot(train_x, train_y)\n",
    "print(\"测试集 {} 个数据点的可视化：\".format(Ntest))\n",
    "data_point_plot(test_x, test_y)\n",
    "print(\"\\n 读者不妨自己调节数据集的参数设置来生成属于自己的数据集吧！\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据的预处理\n",
    "\n",
    "与经典机器学习不同的是，量子分类器在实际工作的时候需要考虑数据的预处理。我们需要多加一个步骤将经典的数据转化成量子信息才能放在量子计算机上运行。接下来我们看看具体是怎么完成的。\n",
    "\n",
    "首先我们确定需要使用的量子比特数量。因为我们的数据 $\\{x^{k} = (x^{k}_0, x^{k}_1)\\}$ 是二维的, 按照 Mitarai (2018) 论文[1]中的编码方式我们至少需要2个量子比特。接着准备一系列的初始量子态 $|00\\rangle$。然后将经典信息 $\\{x^{k}\\}$ 编码成一系列量子门 $U(x^{k})$ 并作用在初始量子态上。最终得到一系列的量子态 $|\\psi\\rangle^k = U(x^{k})|00\\rangle$。这样我们就完成从经典信息到量子信息的编码了！给定 $m$ 个量子比特去编码二维的经典数据点，则量子门的构造为：\n",
    "\n",
    "$$\n",
    "U(x^{k}) = \\otimes_{j=0}^{m-1} R_j^z\\big[\\arccos(x_{j \\, \\text{mod} \\, 2}\\cdot x_{j \\, \\text{mod} \\, 2})\\big] R_j^y\\big[\\arcsin(x_{j \\, \\text{mod} \\, 2}) \\big],\\tag{2}\n",
    "$$\n",
    "\n",
    "**注意** ：这种表示下，我们将第一个量子比特编号为 $j = 0$。更多编码方式见 [Robust data encodings for quantum classifiers](https://arxiv.org/pdf/2003.01695.pdf)。读者也可以直接使用量桨中提供的[编码方式](./DataEncoding_CN.ipynb)。这里我们也欢迎读者自己创新尝试全新的编码方式。\n",
    "由于这种编码的方式看着比较复杂，我们不妨来举一个简单的例子。假设我们给定一个数据点 $x = (x_0, x_1)= (1,0)$, 显然这个数据点的标签应该为 1，对应上图**蓝色**的点。同时数据点对应的2比特量子门 $U(x)$ 是\n",
    "\n",
    "$$\n",
    "U(x) = \n",
    "\\bigg( R_0^z\\big[\\arccos(x_{0}\\cdot x_{0})\\big] R_0^y\\big[\\arcsin(x_{0}) \\big]  \\bigg)\n",
    "\\otimes \n",
    "\\bigg( R_1^z\\big[\\arccos(x_{1}\\cdot x_{1})\\big] R_1^y\\big[\\arcsin(x_{1}) \\big] \\bigg),\\tag{3}\n",
    "$$\n",
    "\n",
    "\n",
    "把具体的数值带入我们就能得到：\n",
    "$$\n",
    "U(x) = \n",
    "\\bigg( R_0^z\\big[0\\big] R_0^y\\big[\\pi/2 \\big]  \\bigg)\n",
    "\\otimes \n",
    "\\bigg( R_1^z\\big[\\pi/2\\big] R_1^y\\big[0 \\big] \\bigg),\n",
    "\\tag{4}\n",
    "$$\n",
    "\n",
    "以下是常用的旋转门的矩阵形式：\n",
    "\n",
    "\n",
    "$$\n",
    "R_x(\\theta) := \n",
    "\\begin{bmatrix} \n",
    "\\cos \\frac{\\theta}{2} &-i\\sin \\frac{\\theta}{2} \\\\ \n",
    "-i\\sin \\frac{\\theta}{2} &\\cos \\frac{\\theta}{2} \n",
    "\\end{bmatrix}\n",
    ",\\quad \n",
    "R_y(\\theta) := \n",
    "\\begin{bmatrix}\n",
    "\\cos \\frac{\\theta}{2} &-\\sin \\frac{\\theta}{2} \\\\ \n",
    "\\sin \\frac{\\theta}{2} &\\cos \\frac{\\theta}{2} \n",
    "\\end{bmatrix}\n",
    ",\\quad \n",
    "R_z(\\theta) := \n",
    "\\begin{bmatrix}\n",
    "e^{-i\\frac{\\theta}{2}} & 0 \\\\ \n",
    "0 & e^{i\\frac{\\theta}{2}}\n",
    "\\end{bmatrix}. \\tag{5}\n",
    "$$\n",
    "\n",
    "那么这个两比特量子门 $U(x)$ 的矩阵形式可以写为：\n",
    "\n",
    "$$\n",
    "U(x) = \n",
    "\\bigg(\n",
    "\\begin{bmatrix}\n",
    "1 & 0 \\\\ \n",
    "0 & 1\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "\\cos \\frac{\\pi}{4} &-\\sin \\frac{\\pi}{4} \\\\ \n",
    "\\sin \\frac{\\pi}{4} &\\cos \\frac{\\pi}{4} \n",
    "\\end{bmatrix}\n",
    "\\bigg)\n",
    "\\otimes \n",
    "\\bigg(\n",
    "\\begin{bmatrix}\n",
    "e^{-i\\frac{\\pi}{4}} & 0 \\\\ \n",
    "0 & e^{i\\frac{\\pi}{4}}\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "1 &0 \\\\ \n",
    "0 &1\n",
    "\\end{bmatrix}\n",
    "\\bigg),\\tag{6}\n",
    "$$\n",
    "\n",
    "化简后我们作用在零初始化的 $|00\\rangle$ 量子态上可以得到编码后的量子态 $|\\psi\\rangle$，\n",
    "\n",
    "$$\n",
    "|\\psi\\rangle =\n",
    "U(x)|00\\rangle = \\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i &0 &-1+i &0 \\\\ \n",
    "0 &1+i &0  &-1-i \\\\\n",
    "1-i &0 &1-i  &0 \\\\\n",
    "0 &1+i &0  &1+i \n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "1 \\\\\n",
    "0 \\\\\n",
    "0 \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "= \\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i \\\\\n",
    "0 \\\\\n",
    "1-i \\\\\n",
    "0\n",
    "\\end{bmatrix}.\\tag{7}\n",
    "$$\n",
    "\n",
    "接着我们来看看代码上怎么实现这种编码方式。需要注意的是：代码中使用了一个张量积来表述\n",
    "\n",
    "$$\n",
    "(U_1 |0\\rangle)\\otimes (U_2 |0\\rangle) = (U_1 \\otimes U_2) |0\\rangle\\otimes|0\\rangle\n",
    "= (U_1 \\otimes U_2) |00\\rangle.\\tag{8}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:06.589265Z",
     "start_time": "2021-03-02T09:15:06.452691Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "作为测试我们输入以上的经典信息:\n",
      "(x_0, x_1) = (1, 0)\n",
      "编码后输出的2比特量子态为:\n",
      "[[[0.5-0.5j 0. +0.j  0.5-0.5j 0. +0.j ]]]\n"
     ]
    }
   ],
   "source": [
    "def Ry(theta):\n",
    "    \"\"\"\n",
    "    :param theta: 参数\n",
    "    :return: Y 旋转矩阵\n",
    "    \"\"\"\n",
    "    return np.array([[np.cos(theta / 2), -np.sin(theta / 2)],\n",
    "                     [np.sin(theta / 2), np.cos(theta / 2)]])\n",
    "\n",
    "def Rz(theta):\n",
    "    \"\"\"\n",
    "    :param theta: 参数\n",
    "    :return: Z 旋转矩阵\n",
    "    \"\"\"\n",
    "    return np.array([[np.cos(theta / 2) - np.sin(theta / 2) * 1j, 0],\n",
    "                     [0, np.cos(theta / 2) + np.sin(theta / 2) * 1j]])\n",
    "\n",
    "# 经典 -> 量子数据编码器\n",
    "def datapoints_transform_to_state(data, n_qubits):\n",
    "    \"\"\"\n",
    "    :param data: 形状为 [-1, 2]\n",
    "    :param n_qubits: 数据转化后的量子比特数量\n",
    "    :return: 形状为 [-1, 1, 2 ^ n_qubits]\n",
    "    \"\"\"\n",
    "    dim1, dim2 = data.shape\n",
    "    res = []\n",
    "    for sam in range(dim1):\n",
    "        res_state = 1.\n",
    "        zero_state = np.array([[1, 0]])\n",
    "        for i in range(n_qubits):\n",
    "            if i % 2 == 0:\n",
    "                state_tmp=np.dot(zero_state, Ry(np.arcsin(data[sam][0])).T)\n",
    "                state_tmp=np.dot(state_tmp, Rz(np.arccos(data[sam][0] ** 2)).T)\n",
    "                res_state=np.kron(res_state, state_tmp)\n",
    "            elif i % 2 == 1:\n",
    "                state_tmp=np.dot(zero_state, Ry(np.arcsin(data[sam][1])).T)\n",
    "                state_tmp=np.dot(state_tmp, Rz(np.arccos(data[sam][1] ** 2)).T)\n",
    "                res_state=np.kron(res_state, state_tmp)\n",
    "        res.append(res_state)\n",
    "\n",
    "    res = np.array(res)\n",
    "    return res.astype(\"complex128\")\n",
    "\n",
    "print(\"作为测试我们输入以上的经典信息:\")\n",
    "print(\"(x_0, x_1) = (1, 0)\")\n",
    "print(\"编码后输出的2比特量子态为:\")\n",
    "print(datapoints_transform_to_state(np.array([[1, 0]]), n_qubits=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造量子神经网络\n",
    "\n",
    "那么在完成上述从经典数据到量子数据的编码后，我们现在可以把这些量子态输入到量子计算机里面了。在那之前，我们还需要设计下我们所采用的量子神经网络结构。\n",
    "\n",
    "![电路结构](figures/qclassifier-fig-circuit.png \"图 3：参数化量子神经网络的电路结构\")\n",
    "<div style=\"text-align:center\">图 3：参数化量子神经网络的电路结构 </div>\n",
    "\n",
    "\n",
    "为了方便，我们统一将上述参数化的量子神经网络称为 $U(\\boldsymbol{\\theta})$。这个 $U(\\boldsymbol{\\theta})$ 是我们分类器的关键组成部分，需要一定的复杂结构来拟合我们的决策边界。与经典神经网络类似，量子神经网络的的设计并不是唯一的，这里展示的仅仅是一个例子，读者不妨自己设计出自己的量子神经网络。我们还是拿原来提过的这个数据点 $x = (x_0, x_1)= (1,0)$ 来举例子，编码过后我们已经得到了一个量子态 $|\\psi\\rangle$，\n",
    "\n",
    "$$\n",
    "|\\psi\\rangle =\n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i \\\\\n",
    "0 \\\\\n",
    "1-i \\\\\n",
    "0\n",
    "\\end{bmatrix},\\tag{9}\n",
    "$$\n",
    "\n",
    "接着我们把这个量子态输入进我们的量子神经网络，也就是把一个酉矩阵乘以一个向量。得到处理过后的量子态 $|\\varphi\\rangle$\n",
    "\n",
    "$$\n",
    "|\\varphi\\rangle = U(\\boldsymbol{\\theta})|\\psi\\rangle,\\tag{10}\n",
    "$$\n",
    "\n",
    "如果我们把所有的参数 $\\theta$ 都设置为 $\\theta = \\pi$, 那么我们就可以写出具体的矩阵了：\n",
    "\n",
    "$$\n",
    "|\\varphi\\rangle = \n",
    "U(\\boldsymbol{\\theta} =\\pi)|\\psi\\rangle =\n",
    "\\begin{bmatrix}\n",
    "0  &0 &-1 &0 \\\\ \n",
    "-1 &0 &0  &0 \\\\\n",
    "0  &1 &0  &0 \\\\\n",
    "0  &0 &0  &1 \n",
    "\\end{bmatrix}\n",
    "\\cdot\n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i \\\\\n",
    "0 \\\\\n",
    "1-i \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "= \\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "-1+i \\\\\n",
    "-1+i \\\\\n",
    "0 \\\\\n",
    "0\n",
    "\\end{bmatrix}.\\tag{11}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:06.795398Z",
     "start_time": "2021-03-02T09:15:06.782149Z"
    }
   },
   "outputs": [],
   "source": [
    "# 模拟搭建量子神经网络\n",
    "def U_theta(theta, n, depth):  \n",
    "    \"\"\"\n",
    "    :param theta: 维数: [n, depth + 3]\n",
    "    :param n: 量子比特数量\n",
    "    :param depth: 电路深度\n",
    "    :return: U_theta\n",
    "    \"\"\"\n",
    "    # 初始化网络\n",
    "    cir = UAnsatz(n)\n",
    "    \n",
    "    # 先搭建广义的旋转层\n",
    "    for i in range(n):\n",
    "        cir.rz(theta[i][0], i)\n",
    "        cir.ry(theta[i][1], i)\n",
    "        cir.rz(theta[i][2], i)\n",
    "\n",
    "    # 默认深度为 depth = 1\n",
    "    # 搭建纠缠层和 Ry旋转层\n",
    "    for d in range(3, depth + 3):\n",
    "        for i in range(n-1):\n",
    "            cir.cnot([i, i + 1])\n",
    "        cir.cnot([n-1, 0])\n",
    "        for i in range(n):\n",
    "            cir.ry(theta[i][d], i)\n",
    "\n",
    "    return cir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 测量与损失函数\n",
    "\n",
    "当我们在量子计算机上（QPU）用量子神经网络处理过初始量子态 $|\\psi\\rangle$ 后， 我们需要重新测量这个新的量子态 $|\\varphi\\rangle$ 来获取经典信息。这些处理过后的经典信息可以用来计算损失函数 $\\mathcal{L}(\\boldsymbol{\\theta})$。最后我们再通过经典计算机（CPU）来不断更新QNN参数 $\\boldsymbol{\\theta}$ 并优化损失函数。这里我们采用的测量方式是测量泡利 $Z$ 算符在第一个量子比特上的期望值。 具体来说，\n",
    "\n",
    "$$\n",
    "\\langle Z \\rangle = \n",
    "\\langle \\varphi |Z\\otimes I\\cdots \\otimes I| \\varphi\\rangle,\\tag{12}\n",
    "$$\n",
    "\n",
    "复习一下，泡利 $Z$ 算符的矩阵形式为：\n",
    "\n",
    "$$\n",
    "Z := \\begin{bmatrix} 1 &0 \\\\ 0 &-1 \\end{bmatrix},\\tag{13}\n",
    "$$\n",
    "\n",
    "继续我们前面的 2 量子比特的例子，测量过后我们得到的期望值就是：\n",
    "$$\n",
    "\\langle Z \\rangle = \n",
    "\\langle \\varphi |Z\\otimes I| \\varphi\\rangle = \n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "-1-i \\quad\n",
    "-1-i \\quad\n",
    "0   \\quad\n",
    "0\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "1  &0 &0  &0 \\\\ \n",
    "0  &1 &0  &0 \\\\\n",
    "0  &0 &-1 &0 \\\\\n",
    "0  &0 &0  &-1 \n",
    "\\end{bmatrix}\n",
    "\\cdot\n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "-1+i \\\\\n",
    "-1+i \\\\\n",
    "0 \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "= 1,\\tag{14}\n",
    "$$\n",
    "\n",
    "好奇的读者或许会问，这个测量结果好像就是我们原来的标签 1 ，这是不是意味着我们已经成功的分类这个数据点了？其实并不然，因为 $\\langle Z \\rangle$ 的取值范围通常在 $[-1,1]$之间。 为了对应我们的标签范围 $y^{k} \\in \\{0,1\\}$, 我们还需要将区间上下限映射上。这个映射最简单的做法就是让\n",
    "\n",
    "$$\n",
    "\\tilde{y}^{k} = \\frac{\\langle Z \\rangle}{2} + \\frac{1}{2} + bias \\quad \\in [0, 1].\\tag{15}\n",
    "$$\n",
    "\n",
    "其中加入偏置（bias）是机器学习中的一个小技巧，目的就是为了让决策边界不受制于原点或者一些超平面。一般我们默认偏置初始化为0，并且优化器在迭代过程中会类似于参数 $\\theta$ 一样不断更新偏置确保 $\\tilde{y}^{k} \\in [0, 1]$。当然读者也可以选择其他复杂的映射（激活函数）比如说 sigmoid 函数。映射过后我们就可以把 $\\tilde{y}^{k}$ 看作是我们估计出的标签（label）了。如果 $\\tilde{y}^{k}< 0.5$ 就对应标签 0，如果 $\\tilde{y}^{k}> 0.5$  就对应标签 1。 我们稍微复习一下整个流程，\n",
    "\n",
    "\n",
    "$$\n",
    "x^{k} \\rightarrow |\\psi\\rangle^{k} \\rightarrow U(\\boldsymbol{\\theta})|\\psi\\rangle^{k} \\rightarrow\n",
    "|\\varphi\\rangle^{k} \\rightarrow ^{k}\\langle \\varphi |Z\\otimes I\\cdots \\otimes I| \\varphi\\rangle^{k}\n",
    "\\rightarrow \\langle Z \\rangle  \\rightarrow \\tilde{y}^{k}.\\tag{16}\n",
    "$$\n",
    "\n",
    "最后我们就可以把损失函数定义为平方损失函数：\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = \\sum_{k} |y^{k} - \\tilde{y}^{k}|^2.\\tag{17}\n",
    "$$\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:07.667491Z",
     "start_time": "2021-03-02T09:15:07.661325Z"
    }
   },
   "outputs": [],
   "source": [
    "# 生成只作用在第一个量子比特上的泡利 Z 算符\n",
    "# 其余量子比特上都作用单位矩阵\n",
    "def Observable(n):\n",
    "    r\"\"\"\n",
    "    :param n: 量子比特数量\n",
    "    :return: 局部可观测量: Z \\otimes I \\otimes ...\\otimes I\n",
    "    \"\"\"\n",
    "    Ob = pauli_str_to_matrix([[1.0, 'z0']], n)\n",
    "    return Ob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:08.373511Z",
     "start_time": "2021-03-02T09:15:08.358729Z"
    }
   },
   "outputs": [],
   "source": [
    "# 搭建整个优化流程图\n",
    "class Net(paddle.nn.Layer):\n",
    "    \"\"\"\n",
    "    创建模型训练网络\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 n,      # 量子比特数量\n",
    "                 depth,  # 电路深度\n",
    "                 seed_paras=1,\n",
    "                 dtype='float64'):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        self.n = n\n",
    "        self.depth = depth\n",
    "        \n",
    "        # 初始化参数列表 theta，并用 [0, 2*pi] 的均匀分布来填充初始值\n",
    "        self.theta = self.create_parameter(\n",
    "            shape=[n, depth + 3],\n",
    "            default_initializer=paddle.nn.initializer.Uniform(low=0.0, high=2*PI),\n",
    "            dtype=dtype,\n",
    "            is_bias=False)\n",
    "        \n",
    "        # 初始化偏置 (bias)\n",
    "        self.bias = self.create_parameter(\n",
    "            shape=[1],\n",
    "            default_initializer=paddle.nn.initializer.Normal(std=0.01),\n",
    "            dtype=dtype,\n",
    "            is_bias=False)\n",
    "\n",
    "    # 定义前向传播机制、计算损失函数 和交叉验证正确率\n",
    "    def forward(self, state_in, label):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            state_in: The input quantum state, shape [-1, 1, 2^n]\n",
    "            label: label for the input state, shape [-1, 1]\n",
    "        Returns:\n",
    "            The loss:\n",
    "                L = ((<Z> + 1)/2 + bias - label)^2\n",
    "        \"\"\"\n",
    "        # 将 Numpy array 转换成 tensor\n",
    "        Ob = paddle.to_tensor(Observable(self.n))\n",
    "        label_pp = paddle.to_tensor(label)\n",
    "\n",
    "        # 按照随机初始化的参数 theta \n",
    "        cir = U_theta(self.theta, n=self.n, depth=self.depth)\n",
    "        Utheta = cir.U\n",
    "        \n",
    "        # 因为 Utheta是学习到的，我们这里用行向量运算来提速而不会影响训练效果\n",
    "        state_out = matmul(state_in, Utheta)  # 维度 [-1, 1, 2 ** n]\n",
    "        \n",
    "        # 测量得到泡利 Z 算符的期望值 <Z>\n",
    "        E_Z = matmul(matmul(state_out, Ob), transpose(paddle.conj(state_out), perm=[0, 2, 1]))\n",
    "        \n",
    "        # 映射 <Z> 处理成标签的估计值 \n",
    "        state_predict = paddle.real(E_Z)[:, 0] * 0.5 + 0.5 + self.bias\n",
    "        loss = paddle.mean((state_predict - label_pp) ** 2)\n",
    "        \n",
    "        # 计算交叉验证正确率\n",
    "        is_correct = (paddle.abs(state_predict - label_pp) < 0.5).nonzero().shape[0]\n",
    "        acc = is_correct / label.shape[0]\n",
    "\n",
    "        return loss, acc, state_predict.numpy(), cir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练效果与调参\n",
    "\n",
    "好了， 那么定义完以上所有的概念之后我们不妨来看看实际的训练效果！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:08.911819Z",
     "start_time": "2021-03-02T09:15:08.887770Z"
    }
   },
   "outputs": [],
   "source": [
    "def heatmap_plot(net, N):\n",
    "    # 生成数据点 x_y_\n",
    "    Num_points = 30\n",
    "    x_y_ = []\n",
    "    for row_y in np.linspace(0.9, -0.9, Num_points):\n",
    "        row = []\n",
    "        for row_x in np.linspace(-0.9, 0.9, Num_points):\n",
    "            row.append([row_x, row_y])\n",
    "        x_y_.append(row)\n",
    "    x_y_ = np.array(x_y_).reshape(-1, 2).astype(\"float64\")\n",
    "\n",
    "    # 计算预测: heat_data\n",
    "    input_state_test = paddle.to_tensor(\n",
    "        datapoints_transform_to_state(x_y_, N))\n",
    "    loss_useless, acc_useless, state_predict, cir = net(state_in=input_state_test, label=x_y_[:, 0])\n",
    "    heat_data = state_predict.reshape(Num_points, Num_points)\n",
    "\n",
    "    # 画图\n",
    "    fig = plt.figure(1)\n",
    "    ax = fig.add_subplot(111)\n",
    "    x_label = np.linspace(-0.9, 0.9, 3)\n",
    "    y_label = np.linspace(0.9, -0.9, 3)\n",
    "    ax.set_xticks([0, Num_points // 2, Num_points - 1])\n",
    "    ax.set_xticklabels(x_label)\n",
    "    ax.set_yticks([0, Num_points // 2, Num_points - 1])\n",
    "    ax.set_yticklabels(y_label)\n",
    "    im = ax.imshow(heat_data, cmap=plt.cm.RdBu)\n",
    "    plt.colorbar(im)\n",
    "    plt.show()\n",
    "\n",
    "def QClassifier(Ntrain, Ntest, gap, N, D, EPOCH, LR, BATCH, seed_paras, seed_data,):\n",
    "    \"\"\"\n",
    "    量子二分类器\n",
    "    \"\"\"\n",
    "    # 生成数据集\n",
    "    train_x, train_y, test_x, test_y = circle_data_point_generator(Ntrain=Ntrain, Ntest=Ntest, boundary_gap=gap, seed_data=seed_data)\n",
    "\n",
    "    # 读取训练集的维度\n",
    "    N_train = train_x.shape[0]\n",
    "    \n",
    "    paddle.seed(seed_paras)\n",
    "    # 定义优化图\n",
    "    net = Net(n=N, depth=D)\n",
    "\n",
    "    # 一般来说，我们利用Adam优化器来获得相对好的收敛\n",
    "    # 当然你可以改成SGD或者是RMSprop\n",
    "    opt = paddle.optimizer.Adam(learning_rate=LR, parameters=net.parameters())\n",
    "\n",
    "    # 初始化寄存器存储正确率 acc 等信息\n",
    "    summary_iter, summary_test_acc = [], []\n",
    "\n",
    "    # 优化循环\n",
    "    for ep in range(EPOCH):\n",
    "        for itr in range(N_train // BATCH):\n",
    "\n",
    "            # 将经典数据编码成量子态 |psi>, 维度 [-1, 2 ** N]\n",
    "            input_state = paddle.to_tensor(datapoints_transform_to_state(train_x[itr * BATCH:(itr + 1) * BATCH], N))\n",
    "\n",
    "            # 前向传播计算损失函数\n",
    "            loss, train_acc, state_predict_useless, cir \\\n",
    "                = net(state_in=input_state, label=train_y[itr * BATCH:(itr + 1) * BATCH])\n",
    "            if itr % 50 == 0:\n",
    "\n",
    "                # 计算测试集上的正确率 test_acc\n",
    "                input_state_test = paddle.to_tensor(datapoints_transform_to_state(test_x, N))\n",
    "                loss_useless, test_acc, state_predict_useless, t_cir \\\n",
    "                    = net(state_in=input_state_test,label=test_y)\n",
    "                print(\"epoch:\", ep, \"iter:\", itr,\n",
    "                      \"loss: %.4f\" % loss.numpy(),\n",
    "                      \"train acc: %.4f\" % train_acc,\n",
    "                      \"test acc: %.4f\" % test_acc)\n",
    "                # 存储正确率 acc 等信息\n",
    "                summary_iter.append(itr + ep * N_train)\n",
    "                summary_test_acc.append(test_acc) \n",
    "            if (itr + 1) % 151 == 0 and ep == EPOCH - 1:\n",
    "                print(\"训练后的电路：\")\n",
    "                print(cir)\n",
    "\n",
    "            # 反向传播极小化损失函数\n",
    "            loss.backward()\n",
    "            opt.minimize(loss)\n",
    "            opt.clear_grad()\n",
    "\n",
    "    # 画出 heatmap 表示的决策边界\n",
    "    heatmap_plot(net, N=N)\n",
    "\n",
    "    return summary_test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上都是我们定义的函数，下面我们将运行主程序。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:50.771171Z",
     "start_time": "2021-03-02T09:15:09.593720Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集的维度大小 x (200, 2) 和 y (200, 1)\n",
      "测试集的维度大小 x (100, 2) 和 y (100, 1) \n",
      "\n",
      "epoch: 0 iter: 0 loss: 0.0318 train acc: 1.0000 test acc: 0.5400\n",
      "epoch: 0 iter: 50 loss: 0.3359 train acc: 0.0000 test acc: 0.8200\n",
      "epoch: 0 iter: 100 loss: 0.0396 train acc: 1.0000 test acc: 0.8700\n",
      "epoch: 0 iter: 150 loss: 0.0952 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 0 loss: 0.1586 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 50 loss: 0.1534 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 100 loss: 0.0624 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 150 loss: 0.0883 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 0 loss: 0.1627 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 50 loss: 0.1378 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 100 loss: 0.0669 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 150 loss: 0.0860 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 0 loss: 0.1658 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 50 loss: 0.1359 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 100 loss: 0.0671 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 150 loss: 0.0849 train acc: 1.0000 test acc: 1.0000\n",
      "训练后的电路：\n",
      "--Rz(0.542)----Ry(3.456)----Rz(2.699)----*--------------x----Ry(6.153)--\n",
      "                                         |              |               \n",
      "--Rz(3.514)----Ry(1.543)----Rz(2.499)----x----*---------|----Ry(3.050)--\n",
      "                                              |         |               \n",
      "--Rz(5.947)----Ry(3.161)----Rz(3.897)---------x----*----|----Ry(1.583)--\n",
      "                                                   |    |               \n",
      "--Rz(0.718)----Ry(5.038)----Rz(1.348)--------------x----*----Ry(0.030)--\n",
      "                                                                        \n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATsAAAD5CAYAAABYi5LMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAg6klEQVR4nO3de5BkZ3nf8e/TPd1z2dnZXe0FARISBJEgAhGxYlyhKlBWIhb+QGA5sUQlEQmxcgE7sYOrrIoruEQRhEMZkypiI2M5wlU2UHIVJSokimyhsstczGKwiOQAkrCxFqHV3rU7l749+aPPLK3Zfp9z+jLbp6d/n6qunZm3z2W6zzx7znmffh5zd0REdrrKpHdARORSULATkZmgYCciM0HBTkRmgoKdiMwEBTsRmQlzgy5gZoeBjwJV4BPufteW8auAe4CDwEngn7r7U3nrrS6ueG33ob5jlYrF+xSMW7RoMBgu133GUMuGq83ZpkVPGHabOYZddpSEpnDZYDAvjSoajZaNVpu7zU6w3mCsE421m+E2O81G/5+vn8Gba6McDlRWrnBa64We62snHnD3w6Nsb5wGCnZmVgU+Bvwj4Cngq2Z2v7s/1vO0DwOfdPd7zezHgQ8C/yxv3bXdh3jJT32479jCUi1ctj6f/jXm6umT17laNTlWqcYnvdW59Hg4FkTCak5Qj8aHHcsz7LLt4I91lGWjsUarE6632U6Pd4JlW812epsbrXibG+llN9bSQWvtuY302Olnwm2eP/a9vj9f/bNPhssV0t6g9sq3F3pq488+cWD0DY7PoJexPwo87u5PunsD+BRw05bnXAs8lH39hT7jIjLFrFIt9CibQYPdi4G/7vn+qexnvf4c+Ins67cDu81s/3C7JyLlYjMT7Ip4L/AGM/s68AbgKND3XN7MbjezI2Z2pL12dht2RUTGyqY32A06QXEUuLLn+yuyn13g7t8nO7Mzs2XgZnc/3W9l7n43cDfAwqGX60O6IiVnZlRr9UnvxlAGDXZfBa4xs5fSDXK3AO/ofYKZHQBOunsHuIPuzKyI7BBlPGsrYqBg5+4tM3sP8ADd1JN73P1RM7sTOOLu9wNvBD5oZg78EfDuIuuuzc/x4r9xWd+xlxzcFS57aGU+GFtIjq0spH/9pWCmFmA+mK2dD2ZjF4KxWiW+q1CrpmdGo2VzVhuKZo8j7RGq6TTb6WWbwWxsNNsKsBHMuK4GM67nGumx5xrxbOyxs+lZ1adPryXHvndiNTl2+thKuM3jK/0nQRuP3hcuV0h2GTuNBs6zc/fPA5/f8rP/3PP1fcAYXlURKRsDbJT/PSdo4GAnIrNshs7sRGSGzdJlrIjMMDMqMzIbKyIzrHvPTmd2IrLTTfFl7HROq4jIhBiVSrXQI3dNZofN7Ftm9riZ/WKf8ZeY2RfM7Otm9oiZvaVn7I5suW+Z2ZuK7HlpzuyW5qu85qq9fceufVGcV3T5cjrP7sBS+v7CniDPbjHIhwNYmEvnn9WDfLh6UEVkLq/qSTTcSed7WTBGJ85NG61YUyBIX/BKusqNV9LvWTPnV2kEeXgbQW7fWrDi1Vb8+jxzLp1n94Ng7Mkgz+6RxbgKUKpKy7Gc3NFCbDyXsQUrKP0S8Bl3/3Uzu5ZuytvV2de3AK8CXgT8gZm9wt3TCZGUKNiJSPkZRmVuLBMUFyooAZjZZgWl3mDnwOaZzh7g+9nXNwGfcvcN4Ltm9ni2vi9FG1SwE5HixnfPrl8Fpddtec4vA//HzH4G2AX8w55lv7xl2a3Vly6ie3YiMoCBqp4c2KxqlD1uH3BjtwL/w92vAN4C/I6ZDR2zdGYnIsUZWLXwmd1xd78+MZZbQQl4F3AYwN2/ZGYLwIGCy15EZ3YiUpiNr3jnhQpKZlanO+Fw/5bnfA+4AcDMXgksAM9mz7vFzOazCkzXAH+at0Gd2YlIcWO6Z1ewgtJ/BH7TzH6O7mTFO73b4ehRM/sM
3cmMFvDuvJlYKFGwm5+rcM0Ldvcdu+ayuMTTC3enU0/2zKdPXpfr6Tet7v07NG2qbJxPjtlGOm3AGumyPtaOt0kj3dXJW8GyrXRjF+/kHCO5qSkJOZUxwj+YuXRqRWV+MTlWn0sfBwCLtfSyPr+cHOss9T8uATaI00BWguPvsiCFJCohthaUnAI4drL/8VfNaSJVVHVuPGGjQAWlx4DXJ5b9APCBQbZXmmAnIuVnZmHr0jJTsBORgdiQBV0nTcFORAaS17S+rBTsRKQ4Q5exIrLzdUs8KdiJyE5nNrZZ3UutNMGuVq1weSKFJEotATi4lE5jWK6kK35Uzh9Pj62dCbdpq+nxzrnTybH2+XQzcF9Lp7MAdNbTKS0epaU002kpntORy4dMPclrylKpBykbQeqJ1dPd4iq74uo4lSCFpLJnf3psuX/Xu+4694XbnF8+mByLOrc12+k0meOr6VQigEcSVYAqYdmc4nRmJyI7npkmKERkRgz/UfzJUrATkYEoz05Edjwzo5pTxbusFOxEZCCaoBCRnc+gosvY0cxVLNkcJ6ocAXF6SfXsM+mx1ZPJsdbxp8Ntdk4dS461o9STc+eSY42z6dQSgOb5dHpJaz3dvKXTTL8+nUbQjIftTD1JH3qVWnqstiuoerKyFG5zfm+QerKSTj2p7r88PXbwReE28fTrt2/5BcmxtVb6NTiwFFdauWy5/9/RXM57UoSSikVkRqjqiYjMginOs5vOaRURmQij+0mMIo/cdeU3yf6ImX0je3zbzE73jLV7xraWc+9LZ3YiUtyYzuyKNMl295/ref7PAK/tWcWau183yDZ1ZiciA7GKFXrkuNAk290bwGaT7JRbgd8bZb8V7ERkANYtzV7gkaNfk+y+ja7N7CrgpcBDPT9eyHrRftnM3lZkz3UZKyKFDVgI4ICZHen5/m53v3uIzd4C3Lelg9hV7n7UzF4GPGRm33T3J6KVlCbYzVUs2W1puZaTs7WWzpcLc+l+8L302Ik4z651Ml0eau1EuozTxunnkmN5eXatMM8uXfanuRbk2TVzO9ANxXJuUEd5eLXFKM8uXe6rvjsnz25fOs9u8WC6vNZ8UD6LnO5s1Wr/nDeAWj3dNW9X0AltX9CVDGBvIg+vOqZZ1AFST0Ztkr3pFuDdvT9w96PZv0+a2cN07+eFwU6XsSJSmBnU5yqFHjmKNMnGzP4WsA/4Us/P9pnZfPb1AbrtFh/buuxWpTmzE5HyM2wsZ4gFm2RDNwh+KmuOvemVwMfNrEP3hO2u3lncFAU7ESnOxnc5nNckO/v+l/ss90Xg1YNuT8FORAozxhfsLjUFOxEpzKw7mTiNFOxEpDAzKzL5UEqlCXYGzCfSFeYr3vfnm6JOYO2oFFMw1ng2PQawdux0euzZU8mxjdPpEk/rZ9JlmgA2zqbHW+vp9JJWkHrinfi17eR0H0up5LTbi9IX5oLUk/mVdOrJ/EqQIgK01qMua8Ol4ETdzgAqu9PdxyqLe5JjS7vSaSm7g/JYAMsL25d60r2MVbATkRmge3YisuPZGGdjLzUFOxEpbFx5dpOgYCciA6mqB4WI7HSbHxebRgp2IlKY8uzGwMyoJ1JPbCOdrgFQaaQrVjSfO50ca585kRxbDyqXQJxesnosnQqzdiqdHrEejAE0zqdTJxrn0lVPorSUTk7qSTseTsqryj1Xq6bHFtKHZfN88HsGKTYA7cZw6SWVWrrKSH0lfQwBdIJjrLLnUHJsYSX9Ai7V068dwGJifBxXn7pnJyIzQ8FORHY8pZ6IyExQIQARmQnT/NnY6dxrEZmYasUKPfKM2Df2NjP7Tva4rch+68xORAob1z27UfrGmtllwPuA6wEHvpYtm06RoETBrmKkU09acTWQzlo69aRzPmp+k05piRrj5C0bpZesHk831dk4m04t6Y6nX4fzzXR1krWgckkzJ7Wk7cPlnuRl2S8E+7u4kU4hWYrSaIas0AJ5TX7SlU02Ti+H663u
Tx9/1Vb6OLF2+ljIu4qcTzyhMobckzHes7vQNxbAzDb7xqbKq99KN8ABvAl40N1PZss+CBwmp69saYKdiEyB8c3G9usb+7q+m7y4b2zhnrO9FOxEpDDDqBWvZ7edfWMHpmAnIoUZ+Z+O6bFdfWOPAm/csuzDeTuj2VgRKc6gUrFCjxxD942l237xxqx/7D7gxuxnIZ3ZiUhh3TO7yfaNdfeTZvZ+ugET4M7NyYqIgp2IDGQcs7owfN/Y7Of3APcMsj0FOxEpbMB7dqVSqmCXqpMV5RwB+HqQZxeMNc+n85waZ9P5cBB3Aovy4aJcurXTcYmnM0Fu2npQqinKs2tMqMTTWnBPZ72TvpXc9nSenZ2NN1oNykrVl6MyTuljITqGIM4BZWMtOWStKM8u3WEN0sU1x1LiyYy5nM5xZVWqYCci5aczOxHZ8Yzx3bO71BTsRKQ41bMTkVmgMzsRmRm6ZyciO56ZUdNs7Gii/B1rpTtKAXQ20tP/ndUobSA99d9ajVMKoi5Xw46da8Ulis4HKSTRsmtB/kh+6sn2lHiqB/d9hk53WY2Pk6hrWdSdrb2eTgOJxgC8kT6OojFrp/enaumSUwD1RDAaxwlZ9zJ2DCuagNIEOxGZDuP4uNgkKNiJSGGaoBCR2WAwpbfsFOxEpLgBi3eWioKdiBSmy1gRmQ26jB0DC/7H6KQrXQB4UCGi3Qy6UQVj7UZc7r61HqSXrKXXuxGsN0oRyRuPxyZR9SRvvemzgyhpNRpbyMmJ2BV0JmsFY9H72VqPO995kBblUUqVp9+zvGCT+juyMXUXm9YzuymN0SIyKWbFHvnriZtkZ8/5J2b2mJk9ama/2/Pzdk8D7YvKufdTnjM7EZkKlTGkJxdpkm1m1wB3AK9391NmdqhnFWvuft0g21SwE5HCuk2yx7KqIk2yfxr4mLufAnD3Y6NsUJexIlJcwUvYApexRRpdvwJ4hZn9iZl92cwO94wtmNmR7OdvK7LrOrMTkcIMG+QydtQm2XPANXR7xF4B/JGZvdrdTwNXuftRM3sZ8JCZfdPdn8hbmYhIYQNMxo7aJPsp4Cvu3gS+a2bfphv8vuruRwHc/Ukzexh4LTAdwS7sWhRMw0M8hR+ll3Qaw6eetBvpfYqWjVI9mjkVRqJlt2MMtq/hTiRKo6lXRvhdgvdl2LHoGALwIO3Hm0HFlCDdqpLz4m53HtyYqp5caJJNN8jdArxjy3M+C9wK/LaZHaB7Wftk1hh71d03sp+/HviVvA2WJtiJSPld4ibZDwA3mtljQBv4BXc/YWZ/H/i4mXXozjvc1TuLm6JgJyIDGVdOcV6TbHd34OezR+9zvgi8etDtKdiJyECmNYVDwU5ECuumlUznx8UU7ERkICrLLiIzYUpP7BTsRKS4cc3GTsJUBDvLybOjE+fEpXgnvd5OToJZJ8qfCvK9om5deTlt8bLDrTdvmzmvfFru7xKNDfe75O1rJ3hfovcsOk6iPLq8ZaeS6TJWRGbElMY6BTsRKU59Y0VkZij1RER2PJ3ZiciMMM3GisgMKNhfooyG+phbXqMMM5s3s09n418xs6tH3tMhebuTfIjIYMy98KNsBg52PY0y3gxcC9xqZtduedq7gFPu/nLgI8CHRt1RESkJ7xR7lMwwZ3YXGmW4ewPYbJTR6ybg3uzr+4AbbFqncETkecw7hR5lM0ywK9Io48Jz3L0FnAH2D7ODIlIm3v3EUpFHjhH7xt5mZt/JHrcV2fOJTlCY2e3A7QBXXnllzrNFZOLcx3KJOkrfWDO7DHgfcD3dDyZ+LVv2VLTNYc7sijTKuPAcM5sD9gAntq7I3e929+vd/fr9Bw4MsSsicqmN6TK2yO2wVN/YNwEPuvvJbOxB4DA5hgl2FxplmFmdbqOM+7c8535g89TyJ4GHshLLIjLtxjNBMUrf2CLLXmTgy9iCjTJ+C/gdM3scOEk3IA7NLScmV6rJIQtaLVklPZbXwakSrje9bJSQmdeRK142
Grv0/8/k/y7R2HCvUd7/3JXgfYnes+g4iY6vvGWn00CXsdvSN3aA5S9a2cAKNMpYB/7xsDslIiXlDBLstqtv7FG6AbB32Yfzdman/bcjItvKsXar0CNHkdthnyULar19Y/lhi8V9WQ/ZG7OfhfRxMREZzBhmY0fpGwtgZu+nGzAB7nT3k3nbVLATkeLcu4+xrGq4vrHZ2D3APYNsT8FORAZTwk9HFKFgJyIDKeNHwYpQsBORAYznExSTUJpg5wSdo3Ly7Gyulhyr1NK/YqWeHsvLn6rW0+PVejrvrx7kc9VyaiVEy0ZjbR++BkNe97GUvDy7YX+XYccgzo2M3rNoLDqGICfPs1ZPjkW5pZ2c1m3bWr3MHTq5M62lVJpgJyLlZ+gyVkRmxZT2wlWwE5EBjC/15FJTsBOR4gb7uFipKNiJyAAc0wSFiMwEndmNyKGTuhdQyZnen0tP4Vej1JNgrLYYb3NuIZ3uEi07fz6dxrCYkzPQDFJI4vSSIOUip/xTe8j7M3m9RaM0kcUgb2UxSOWIlgOYW0i/L3EqUTAWHEMANr+QHgtSpjw45vNSS1J/R2MpKeleqOR6GZUn2InIVHDNxorIzqczOxGZBY6CnYjsfO6ON5uT3o2hKNiJyACm9zJWZdlFpDh3vNMu9MiT1yTbzN5pZs+a2Teyx7/qGWv3/HxrOfe+SnNm50ArMcnjwRQ9xNP7laWl5Fht12JybG4pvc7ushvBWJCWspZOyFxuxgfI8BVI0gs2OpNJPRk2vWRXsNzSfHw4z6+kU5TmV+aTY7Vd6WNhLjiGAKwepJ4EY15NH0PtnAOhkchNGduHvMYwG1ukSXbm0+7+nj6rWHP36wbZZmmCnYhMAy901lbAhSbZAGa22SR7a7AbG13Gikhxm7OxRR6xoo2ubzazR8zsPjPrbb24YGZHsubZbyuy6zqzE5HiBpuNHbVJ9ueA33P3DTP718C9wI9nY1e5+1EzexnwkJl9092fiFamYCciAxhoNnakJtmbbRMznwB+pWfsaPbvk2b2MPBaIAx2uowVkeI2Pxs7+mVsbpNsM3thz7dvBf4i+/k+M5vPvj4AvJ4C9/p0ZiciAxnHZ2MLNsn+WTN7K9ACTgLvzBZ/JfBxM+vQPWG7q88s7kVKE+zcoZlIg/BqOmUAoLKwa6ix+ko6LSUaA1hYXU+OtdbT6SXtxvAzWZWz6XSXejOdkrFQSR+ceekswx7WeZcMw1Y9idJLFval00cgTi+Jxuq7g+MkGAOoLKaPP+bTaSteS6eltNbjd6WRyOEaT4Hh8SUVF2iSfQdwR5/lvgi8etDtlSbYiUj5uTve0sfFRGSnUyEAEZkN0/vZWAU7ESnOwdsKdiKy47n6xorIjNBlrIjseO50NBs7GgcaiaQvr+WU0VlaTo5Vdu9Njs3vPZ0ca5xdDbfZaaRz6TpB+6dKkENWrac7j0Hctax+Ln0A7g7y/lrbVlYqHp+rpX/XqAvYsGWaAJYOpI+jhf27g7E9ybH63pVwm5Xlvckxr6Vz9NqWfg02Wul8S4DVRC5nsnvfINzxvPZmJVWaYCci5eeOgp2IzAJXK0URmQE6sxORWeDutBuaoBCRGaDLWBHZ+TQbOzp3T6eeLKdTSwB8Pj39X9mzPz12/rnk2K7Lhz9Vt6A71tzCuWAsfjsa59P71Aq6lo1Scspzuo+lWFDCCeI0m+h1qC+nu27lpZ5E6SVLh/YFY3uTY9WVy8JtRsdfZz5d/mkt1WoPWM1JF1pLpEWNJfWE6b1np0rFIlKYe3c2tsgjz4h9Y28zs+9kj9uK7HtpzuxEZDpESfNFjdI31swuA94HXE/38whfy5Y9FW1TZ3YiUlzH6TRahR45LvSNdfcGsNk3tog3AQ+6+8kswD0IHM5bSMFORApzGNdl7Ch9Y4su+zwKdiJSXDYbW+RB1je253H7gFv7HHC1u7+G7tnbvaPsuu7ZichABpiN3a6+sUeBN25Z9uG8nSlNsOt4ekp9
tRXv5vJiuipFde+h5NhcUKrGKvFJ7+56OgWitivdGaq+kk43mN97Ptxm83zQ0Szodhall+SlnnSGLHsSVXeBvNST9Gs7F762caev+b1B6snBveltHrg8OVY9dEW4TXYfSA55cNyuNtKv+5mN+H7YmdX+x3V7yDSi53HojCep+ELfWLrB6xbgHb1PMLMXuvvT2bcX+sbSbb/4X8xsM1/oRvp0IduqNMFORMrPGU9S8Sh9Y939pJm9n27ABLjT3U/mbVPBTkSKc6fTHM9nY4ftG5uN3QPcM8j2FOxEpDhVPRGR2aDPxorIDHAfzycoJkHBTkQGoErFI2t1OhxfbfQd278U7+biUrpiBZ30NH3V0uklVk+nOADYrnSllaXlE8mxhYPpSivN59IVUQBaUerJev/XDqDTTL8GeQ2Ph/1fvBJUfgGwajr1pLqQbqoTpfXUluL3rLqyNz0WVCep7k+nnrD3BeE2O8sHk2Orlq7Sci6YBDi1Fk8QnE6lngzbPalXBzo56UplVZpgJyLl57guY0VkBvjwNQ4nTcFORAYy7KdqJk3BTkQKU99YEZkN7rjO7ERkx/P84hFlpWAnIoU50NEExWiaHefY+f65Yod2xV2jakEnq7270nlOPpdeb6WeLsUEMLcnvV7W0rl0ndX0WG0tLvHkjXSenW+spceCUlbeyflfetgE0pwSWVZJ59nZXLrEky2m3xdbiN+zSpAbact7k2Od+XRpqM6udH4ewGp1MTl2YjWd//iD59J5k8+c2wi3efy5/uOtsZR40mWsiMwI5dmJyI7XnY3VmZ2I7HRTHOzUcEdEinOn3WwXeuTJa5Ld87ybzczN7Prs+6vNbK2nefZvFNl1ndmJSGHOeD5BUbRJtpntBv498JUtq3jC3a8bZJs6sxOR4gZrpRgp2iT7/cCHgHQqQkGlObNrtDt871T/9InlerybzU46hWStmV52ub6cHFtcSacbAMxX0v+7WWM1PdZKpw1YK34/K0EKSVTKyoIxfEIza0F5La+mSzx5kLLitbjEU7uWTgPxeroz2Yant3muGb9+Z4JyTFF6yROn0sfQk8fiUmCrZ/sfR+OaRR3TPbt+ja5f1/sEM/u7wJXu/j/N7Be2LP9SM/s6cBb4JXf/47wNlibYiUj5dSsVFw52B8zsSM/3d7v73UUWNLMK8KtkHcW2eBp4ibufMLMfAT5rZq9y97PROhXsRKS4bIKioFGaZO8G/jbwsJkBXA7cb2ZvdfcjwEZ3d/xrZvYE8AqgN7BeRMFORIobX+pJ2CTb3c8AFzqMm9nDwHvd/YiZHQROunvbzF4GXAM8mbdBBTsRKcwZT4mngk2yU/4BcKeZNYEO8G/UJFtExmuwe3bxqnKaZG/5+Rt7vv594PcH3Z6CnYgMQIUARrbWaPN/nzrTdyyvWsPx1XTKwYGldBrDnvkgLWU+nW4AsFRLj9er6VSYWiW9r3P1veE2g9VSDSq/BEMEQ9myec/or+PxexaNRm93dCw0c46TVjM9vr6evjRbHaHT1/FEpy+Iq5d85wfp6jj/72g46ciZ4/3TVtqtcVx+5r+3ZVWaYCci5edAQ/XsRGQWtHVmJyI7nQNTestOwU5EinPXmZ2IzAid2YnIjue4zuxGtdFo88Rfneo79syZuBrIvl1BeslSunnL3mBsMafSylI9nXqyGIzVq+lqH7VqnOYRpYFETYeitJQ825V60g5m9KLJvugPrZlzyrHeSn+mcy1oD/jcerpqzJnVdOUSgBPn0uOpxjiQrlwC6dSSTaefPtr3560ghaao7mzsyKuZiNIEOxEpP92zE5GZoXt2IrLjdVNPpjPaKdiJSGHKsxORmeCuj4uJyIyY1svYgbqLWdd/y/o8PpI1xOj3vJ/Kxh81sw+NZ1dFZNKcbrXMIo+yGfTM7s10SyBfQ7cT0K9zcUeg/cB/BX7E3Z81s3vN7AZ3/8NoxY21Df760cf7jtUW0l3AAKr1dNeo6lw6560yl471
efllQXMsLCq3FIxFy+WxIfPhov3ZTp0hL4U8OKvwnHVG2+y00mPtoDJvqxF0bgPajf4d8wCa6+kuYc3z6TJOG8/FRXk3zh7vvy/r58PlihlfUrGZHQY+SrdS8Sfc/a7E824G7gP+XtZ/AjO7A3gX0AZ+1t0fyNveoH1jbwI+6V1fBvaa2Qu3POdlwHfc/dns+z8Abh5wOyJSQpsTFEUekZ4m2W8GrgVuNbNr+zzvoibZ2fNuAV4FHAb+e7a+0KDBrl+vxxdvec7jwN80s6vNbA54G8/vIiQiU2oz9aTII8coTbJvAj7l7hvu/l26MedH8zY4aLDL5e6ngH8LfBr4Y+Av6Z5qXsTMbjezI2Z2pLMRN/4VkcnbnI0t8iDrG9vzuL1nVbknTr1NsrfsRpGTrovk3rMzs3cDP519+1XiXo8AuPvngM9ly99OIthlDXPvBqhddtV0TvGIzJgB8uyivrGhnCbZQ8k9s3P3j7n7de5+HfBZ4J9ns7I/Bpxx96f77Oih7N99wL8DPjGuHRaRyRnjZewgTbL/Evgxuk2yry+wbF+DzsZ+HngL3WvkVeBfbA6Y2TeygAjwUTP7O9nXd7r7twfcjoiU0Bg/QTFKk+w14HfN7FeBF9HNDvnTvA1aNJV/KZnZs8BfZd8eAPrPn0tZ6T0rp9735Sp3PzjKyszsf9MThHIcd/fDwbreAvwaP2yS/YFUk+zeYJd9/5+Afwm0gP/g7v8rd9/LEux6mdmRYa/1ZTL0npWT3pcfGvtsrIhIGSnYichMKGuwu3vSOyAD03tWTnpfMqW8ZyciMm5lPbMTERmriQY7lYyaPmZ22My+lb1nv9hnfN7MPp2Nf8XMrp7Abs6UAu/JVWb2h9nf0MNmdsUk9nPSJn1m11sy6na6JaOep6dk1A3u/irgcjO74ZLupQCFK1W8Czjl7i8HPkL3Q9yyTQq+Jx+mW63oNcCdwAcv7V6Ww6SDnUpGTZcilSpuAu7Nvr4PuMGGLbYnRRR5T64FHsq+/kKf8Zkw6WCnklHTpcj7deE57t4CzgD7L8nezaYi78mfAz+Rff12YHd2xTRTJh3scg1SMkpE+nov8AYz+zrwBrqfRZ25v6FL3nBnO0tGybYrUm1i8zlPZWfie4ATl2b3ZlLue+Lu3yc7szOzZeBmdz99qXawLC75mZ1KRk21C5UqzKxOt1LF/Vuecz9wW/b1TwIPuZI5t1Pue2JmB7L6cAB3APdc4n0shUlfxn4eeJLufbnfpBvIgG7JqJ7nfdTMHgP+BLhLJaMmI7sH9x7gAeAvgM+4+6NmdqeZvTV72m8B+83sceDngYtSIWR8Cr4nbwS+ZWbfBl4AfGAiOzth+gSFiMyESZ/ZiYhcEgp2IjITFOxEZCYo2InITFCwE5GZoGAnIjNBwU5EZoKCnYjMhP8PWyDVapqkw5MAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The main program ran for 23.798545360565186 seconds\n"
     ]
    }
   ],
   "source": [
    "def main():\n",
    "    \"\"\"\n",
    "    Main function: train and evaluate the quantum classifier.\n",
    "    \"\"\"\n",
    "    time_start = time.time()\n",
    "    acc = QClassifier(\n",
    "        Ntrain = 200,        # Size of the training set\n",
    "        Ntest = 100,         # Size of the test set\n",
    "        gap = 0.5,           # Width of the decision boundary\n",
    "        N = 4,               # Number of qubits required\n",
    "        D = 1,               # Circuit depth\n",
    "        EPOCH = 4,           # Number of training epochs\n",
    "        LR = 0.01,           # Learning rate\n",
    "        BATCH = 1,           # Batch size used during training\n",
    "        seed_paras = 19,     # Random seed for initializing the parameters\n",
    "        seed_data = 2,       # Fixed random seed for generating the data set\n",
    "    )\n",
    "\n",
    "    time_span = time.time() - time_start\n",
    "    print('The main program ran for', time_span, 'seconds')\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed training results show that, after optimization, the classifier reaches $100\\%$ accuracy on both the training set and the test set."
   ]
  },
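  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, one common way to define the accuracy of a two-class classifier is sketched below. It assumes that the predicted label $\\tilde{y}^{k}$ is a continuous value in $[0,1]$ and that it is rounded to the nearer of the two labels; this thresholding at $1/2$ is an assumption made for illustration and may differ from the rule actually implemented in `QClassifier`. For a set of $N$ labeled data points $(x^k, y^k)$ with $y^k \\in \\{0,1\\}$,\n",
    "\n",
    "$$\n",
    "\\mathrm{acc} = \\frac{1}{N}\\sum_{k=1}^{N} \\mathbb{1}\\left[\\,|\\tilde{y}^{k}-y^k| \\lt \\tfrac{1}{2}\\,\\right],\n",
    "$$\n",
    "\n",
    "where $\\mathbb{1}[\\cdot]$ equals $1$ if the condition holds and $0$ otherwise. Under this definition, an accuracy of $100\\%$ on the test set of `Ntest = 100` points would mean that all $100$ predictions fall within $1/2$ of their true labels."
   ]
  },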
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## References\n",
    "\n",
    "[1] Mitarai, Kosuke, et al. Quantum circuit learning. [Physical Review A 98.3 (2018): 032309.](https://arxiv.org/abs/1803.00745)\n",
    "\n",
    "[2] Farhi, Edward, and Hartmut Neven. Classification with quantum neural networks on near term processors. [arXiv preprint arXiv:1802.06002 (2018).](https://arxiv.org/abs/1802.06002)\n",
    "\n",
    "[3] Schuld, Maria, et al. Circuit-centric quantum classifiers. [Physical Review A 101.3 (2020): 032308.](https://arxiv.org/abs/1804.00633)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
